Unverified Commit ac5bcf23 authored by Jared T Nielsen, committed by GitHub

Fix FFN dropout in TFAlbertLayer, and split dropout in TFAlbertAttent… (#4323)

* Fix FFN dropout in TFAlbertLayer, and split dropout in TFAlbertAttention into two separate dropout layers.

* Same dropout fixes for PyTorch.
parent 4ffea5ce
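
To make the intent of the split concrete, here is a minimal, self-contained sketch in plain PyTorch. It is not the library's `AlbertAttention`; the class name `ToyAlbertStyleAttention` and the simplified single-head layout are hypothetical, used only to illustrate the pattern the diff adopts: one dropout driven by `attention_probs_dropout_prob` applied to the attention probabilities, and a separate dropout driven by `hidden_dropout_prob` applied to the projected output before the residual layer norm.

```python
import torch
import torch.nn as nn


class ToyAlbertStyleAttention(nn.Module):
    """Hypothetical, simplified single-head attention illustrating the two separate dropouts."""

    def __init__(self, hidden_size, attention_probs_dropout_prob=0.0, hidden_dropout_prob=0.1):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)
        # Two distinct probabilities, matching the config fields used in the diff.
        self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states):
        q, k, v = self.query(hidden_states), self.key(hidden_states), self.value(hidden_states)
        scores = torch.matmul(q, k.transpose(-1, -2)) / (hidden_states.size(-1) ** 0.5)
        probs = nn.functional.softmax(scores, dim=-1)
        # Dropout on the attention probabilities (attention_probs_dropout_prob).
        probs = self.attention_dropout(probs)
        context = torch.matmul(probs, v)
        # Separate dropout on the projected context layer (hidden_dropout_prob).
        projected = self.output_dropout(self.dense(context))
        return self.LayerNorm(hidden_states + projected)


# Usage sketch: a batch of 2 sequences, 8 tokens, hidden size 16.
attn = ToyAlbertStyleAttention(hidden_size=16)
out = attn(torch.randn(2, 8, 16))
```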
@@ -212,7 +212,8 @@ class AlbertAttention(BertSelfAttention):
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.attention_head_size = config.hidden_size // config.num_attention_heads
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.pruned_heads = set()
@@ -256,7 +257,7 @@ class AlbertAttention(BertSelfAttention):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.dropout(attention_probs)
+        attention_probs = self.attention_dropout(attention_probs)

         # Mask heads if we want to
         if head_mask is not None:
@@ -275,7 +276,7 @@ class AlbertAttention(BertSelfAttention):
         b = self.dense.bias.to(context_layer.dtype)

         projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
-        projected_context_layer_dropout = self.dropout(projected_context_layer)
+        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
         layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout)
         return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
@@ -290,6 +291,7 @@ class AlbertLayer(nn.Module):
         self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
         self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
         self.activation = ACT2FN[config.hidden_act]
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(
         self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
@@ -298,6 +300,7 @@ class AlbertLayer(nn.Module):
         ffn_output = self.ffn(attention_output[0])
         ffn_output = self.activation(ffn_output)
         ffn_output = self.ffn_output(ffn_output)
+        ffn_output = self.dropout(ffn_output)
         hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

         return (hidden_states,) + attention_output[1:]  # add attentions if we output them
@@ -274,6 +274,8 @@ class TFAlbertSelfOutput(tf.keras.layers.Layer):
 class TFAlbertAttention(TFBertSelfAttention):
+    """ Contains the complete attention sublayer, including both dropouts and layer norm. """
+
     def __init__(self, config, **kwargs):
         super().__init__(config, **kwargs)
@@ -284,6 +286,9 @@ class TFAlbertAttention(TFBertSelfAttention):
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.pruned_heads = set()
+        # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
+        self.attention_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
+        self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

     def prune_heads(self, heads):
         raise NotImplementedError
@@ -314,7 +319,7 @@ class TFAlbertAttention(TFBertSelfAttention):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.dropout(attention_probs, training=training)
+        attention_probs = self.attention_dropout(attention_probs, training=training)

         # Mask heads if we want to
         if head_mask is not None:
@@ -332,7 +337,7 @@ class TFAlbertAttention(TFBertSelfAttention):
         hidden_states = self_outputs[0]
         hidden_states = self.dense(hidden_states)
-        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = self.output_dropout(hidden_states, training=training)
         attention_output = self.LayerNorm(hidden_states + input_tensor)

         # add attentions if we output them
@@ -369,8 +374,8 @@ class TFAlbertLayer(tf.keras.layers.Layer):
         ffn_output = self.ffn(attention_outputs[0])
         ffn_output = self.activation(ffn_output)
         ffn_output = self.ffn_output(ffn_output)
-        hidden_states = self.dropout(hidden_states, training=training)
+        ffn_output = self.dropout(ffn_output, training=training)
         hidden_states = self.full_layer_layer_norm(ffn_output + attention_outputs[0])

         # add attentions if we output them
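
For reference, a minimal sketch of the feed-forward sub-layer after this change (plain PyTorch, with a hypothetical `ToyAlbertFFN` name; not the library code): dropout is applied to `ffn_output` before the residual addition and layer norm, whereas the previous TF code dropped out the stale `hidden_states` tensor and left the FFN output unregularized.

```python
import torch
import torch.nn as nn


class ToyAlbertFFN(nn.Module):
    """Hypothetical feed-forward sub-layer showing where the dropout now sits."""

    def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob=0.1):
        super().__init__()
        self.ffn = nn.Linear(hidden_size, intermediate_size)
        self.ffn_output = nn.Linear(intermediate_size, hidden_size)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.full_layer_layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, attention_output):
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        # Dropout now regularizes the FFN output itself, not an unrelated tensor.
        ffn_output = self.dropout(ffn_output)
        return self.full_layer_layer_norm(ffn_output + attention_output)


# Usage sketch: hidden size 16, intermediate size 64.
ffn = ToyAlbertFFN(hidden_size=16, intermediate_size=64)
out = ffn(torch.randn(2, 8, 16))
```

The TF version in the diff mirrors this ordering, with `training=training` forwarded to the dropout call.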