Commit c1f1849f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #7706 from Vooblin:Vooblin-patch-1

PiperOrigin-RevId: 276182092
parents e0460db7 c34c439b
......@@ -341,11 +341,8 @@ class EmbeddingPostprocessor(tf.keras.layers.Layer):
output = word_embeddings
if self.use_type_embeddings:
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
one_hot_ids = tf.one_hot(
flat_token_type_ids,
depth=self.token_type_vocab_size,
dtype=self.dtype)
token_type_embeddings = tf.matmul(one_hot_ids, self.type_embeddings)
token_type_embeddings = tf.gather(self.type_embeddings,
flat_token_type_ids)
token_type_embeddings = tf.reshape(token_type_embeddings,
[batch_size, seq_length, width])
output += token_type_embeddings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment