Unverified Commit 2f463eff authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TFDebertaV2ConvLayer in TFDebertaV2Model (#16031)



* fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 1da84ae0
......@@ -313,7 +313,7 @@ class TFDebertaV2ConvLayer(tf.keras.layers.Layer):
rmask = tf.cast(1 - input_mask, tf.bool)
out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
out = self.dropout(out, training=training)
hidden_states = self.conv_act(out)
out = self.conv_act(out)
layer_norm_input = residual_states + out
output = self.LayerNorm(layer_norm_input)
......@@ -323,10 +323,10 @@ class TFDebertaV2ConvLayer(tf.keras.layers.Layer):
else:
if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)):
if len(shape_list(input_mask)) == 4:
mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1)
mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32)
input_mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1)
input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32)
output_states = output * mask
output_states = output * input_mask
return output_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment