Unverified Commit c93ac621 authored by Yichao 'Peak' Ji's avatar Yichao 'Peak' Ji Committed by GitHub
Browse files

Fix output shape of the position embedding layer

parent 6d7030f2
......@@ -116,10 +116,10 @@ class PositionEmbedding(tf.keras.layers.Layer):
seq_length = input_shape[1]
width = input_shape[2]
position_embeddings = tf.expand_dims(
tf.slice(self._position_embeddings, [0, 0], [seq_length, width]),
axis=0)
position_embeddings = tf.slice(self._position_embeddings,
[0, 0],
[seq_length, width])
else:
position_embeddings = tf.expand_dims(self._position_embeddings, axis=0)
position_embeddings = self._position_embeddings
return position_embeddings
return tf.broadcast_to(position_embeddings, input_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment