Unverified Commit bd64315e authored by Yichao 'Peak' Ji's avatar Yichao 'Peak' Ji Committed by GitHub
Browse files

Minor code style updates

parent c93ac621
...@@ -111,14 +111,9 @@ class PositionEmbedding(tf.keras.layers.Layer): ...@@ -111,14 +111,9 @@ class PositionEmbedding(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs):
"""Implements call() for the layer.""" """Implements call() for the layer."""
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
if self._use_dynamic_slicing: if self._use_dynamic_slicing:
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3) position_embeddings = self._position_embeddings[:input_shape[1], :]
seq_length = input_shape[1]
width = input_shape[2]
position_embeddings = tf.slice(self._position_embeddings,
[0, 0],
[seq_length, width])
else: else:
position_embeddings = self._position_embeddings position_embeddings = self._position_embeddings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment