Unverified Commit f1679d7c authored by Zhu Baohe's avatar Zhu Baohe Committed by GitHub
Browse files

Fix dropout in TFMobileBert (#5150)

parent 5ed94b23
...@@ -370,7 +370,7 @@ class TFMobileBertOutput(tf.keras.layers.Layer): ...@@ -370,7 +370,7 @@ class TFMobileBertOutput(tf.keras.layers.Layer):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
if not self.use_bottleneck: if not self.use_bottleneck:
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1) hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
else: else:
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1) hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment