Unverified Commit 2f463eff authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TFDebertaV2ConvLayer in TFDebertaV2Model (#16031)



* fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 1da84ae0
...@@ -313,7 +313,7 @@ class TFDebertaV2ConvLayer(tf.keras.layers.Layer): ...@@ -313,7 +313,7 @@ class TFDebertaV2ConvLayer(tf.keras.layers.Layer):
rmask = tf.cast(1 - input_mask, tf.bool) rmask = tf.cast(1 - input_mask, tf.bool)
out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out) out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
out = self.dropout(out, training=training) out = self.dropout(out, training=training)
hidden_states = self.conv_act(out) out = self.conv_act(out)
layer_norm_input = residual_states + out layer_norm_input = residual_states + out
output = self.LayerNorm(layer_norm_input) output = self.LayerNorm(layer_norm_input)
...@@ -323,10 +323,10 @@ class TFDebertaV2ConvLayer(tf.keras.layers.Layer): ...@@ -323,10 +323,10 @@ class TFDebertaV2ConvLayer(tf.keras.layers.Layer):
else: else:
if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)): if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)):
if len(shape_list(input_mask)) == 4: if len(shape_list(input_mask)) == 4:
mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1) input_mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1)
mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32) input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32)
output_states = output * mask output_states = output * input_mask
return output_states return output_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment