"vscode:/vscode.git/clone" did not exist on "b9911dcb2f48273479871d018823e11d03642d92"
Unverified Commit c93ac621 authored by Yichao 'Peak' Ji's avatar Yichao 'Peak' Ji Committed by GitHub
Browse files

Fix output shape of the position embedding layer

parent 6d7030f2
......@@ -116,10 +116,10 @@ class PositionEmbedding(tf.keras.layers.Layer):
seq_length = input_shape[1]
width = input_shape[2]
position_embeddings = tf.expand_dims(
tf.slice(self._position_embeddings, [0, 0], [seq_length, width]),
axis=0)
position_embeddings = tf.slice(self._position_embeddings,
[0, 0],
[seq_length, width])
else:
position_embeddings = tf.expand_dims(self._position_embeddings, axis=0)
position_embeddings = self._position_embeddings
return position_embeddings
return tf.broadcast_to(position_embeddings, input_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment