Unverified commit 935e3469 authored by Cola, committed by GitHub

🎨 Change nn.dropout to layer.Dropout (#9047)

parent b01ddc95
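
Note (not part of the patch): the diff below replaces functional `tf.nn.dropout` calls, which need manual `rate if training else 0` gating, with `tf.keras.layers.Dropout` layers that handle the `training` flag themselves. A minimal sketch of the two patterns, using a made-up tensor and rate for illustration:

```python
import tensorflow as tf

x = tf.ones((2, 4))      # toy input, illustration only
rate = 0.1               # stands in for config.dropout
training = True

# Old pattern: functional dropout; the caller disables it at inference
# time by passing rate=0.
y_old = tf.nn.dropout(x, rate=rate if training else 0)

# New pattern: a Dropout layer owns the rate and respects `training` itself.
dropout = tf.keras.layers.Dropout(rate)
y_train = dropout(x, training=True)    # dropout applied
y_eval = dropout(x, training=False)    # identity, no masking or scaling
```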
@@ -235,9 +235,9 @@ class TFEncoderLayer(tf.keras.layers.Layer):
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
- self.dropout = config.dropout
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout
+ self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
@@ -261,7 +261,7 @@ class TFEncoderLayer(tf.keras.layers.Layer):
assert shape_list(x) == shape_list(
residual
), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}"
- x = tf.nn.dropout(x, rate=self.dropout if training else 0)
+ x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
@@ -270,9 +270,9 @@ class TFEncoderLayer(tf.keras.layers.Layer):
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
- x = tf.nn.dropout(x, rate=self.activation_dropout if training else 0)
+ x = self.activation_dropout(x, training=training)
x = self.fc2(x)
- x = tf.nn.dropout(x, rate=self.dropout if training else 0)
+ x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
@@ -293,7 +293,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, embed_tokens: TFSharedEmbeddings, **kwargs):
super().__init__(**kwargs)
- self.dropout = config.dropout
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
self.layerdrop = config.encoder_layerdrop
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
@@ -370,7 +370,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
embed_pos = self.embed_positions(input_ids)
x = inputs_embeds + embed_pos
x = self.layernorm_embedding(x)
- x = tf.nn.dropout(x, rate=self.dropout if training else 0)
+ x = self.dropout(x, training=training)
# B x T x C -> T x B x C
x = tf.transpose(x, perm=[1, 0, 2])
@@ -413,9 +413,9 @@ class TFDecoderLayer(tf.keras.layers.Layer):
dropout=config.attention_dropout,
name="self_attn",
)
- self.dropout = config.dropout
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout
+ self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
@@ -467,7 +467,7 @@ class TFDecoderLayer(tf.keras.layers.Layer):
attn_mask=causal_mask,
key_padding_mask=decoder_padding_mask,
)
- x = tf.nn.dropout(x, rate=self.dropout if training else 0)
+ x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
@@ -481,7 +481,7 @@ class TFDecoderLayer(tf.keras.layers.Layer):
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
)
- x = tf.nn.dropout(x, rate=self.dropout if training else 0)
+ x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
@@ -490,9 +490,9 @@ class TFDecoderLayer(tf.keras.layers.Layer):
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
- x = tf.nn.dropout(x, rate=self.activation_dropout if training else 0)
+ x = self.activation_dropout(x, training=training)
x = self.fc2(x)
- x = tf.nn.dropout(x, rate=self.dropout if training else 0)
+ x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
@@ -545,7 +545,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
else None
)
- self.dropout = config.dropout
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.use_cache = config.use_cache
@@ -588,7 +588,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
x = self.layernorm_embedding(x) + positions
else:
x = self.layernorm_embedding(x + positions)
- x = tf.nn.dropout(x, rate=self.dropout if training else 0)
+ x = self.dropout(x, training=training)
# Convert to Bart output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
x = tf.transpose(x, perm=(1, 0, 2))
@@ -674,7 +674,7 @@ class TFAttention(tf.keras.layers.Layer):
self.embed_dim = embed_dim
self.num_heads = num_heads
- self.dropout = dropout
+ self.dropout = tf.keras.layers.Dropout(dropout)
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
@@ -772,7 +772,7 @@ class TFAttention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
- attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout if training else 0.0)
+ attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, v) # shape: (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = tf.transpose(attn_output, perm=(1, 0, 2))
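
For context, a self-contained sketch of the pattern the patch applies throughout these layers; `ToyBlock` and its arguments are hypothetical names for illustration, not classes from `modeling_tf_bart.py`:

```python
import tensorflow as tf

class ToyBlock(tf.keras.layers.Layer):
    """Illustrative layer only; mirrors the patched pattern, not real BART code."""

    def __init__(self, units, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.fc = tf.keras.layers.Dense(units)
        # Store a Dropout layer instead of a bare float rate.
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        x = self.fc(x)
        # The layer handles train/inference behavior from the `training` flag;
        # no `rate if training else 0` gating at the call site.
        return self.dropout(x, training=training)

block = ToyBlock(units=8, dropout_rate=0.1)
out = block(tf.ones((2, 4)), training=True)
```

One consequence of the change: attributes such as `self.dropout` now hold a `Dropout` layer rather than a float, so any code that previously read them as rates would need to take the rate from the config instead.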