Commit c1f1849f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #7706 from Vooblin:Vooblin-patch-1

PiperOrigin-RevId: 276182092
parents e0460db7 c34c439b
...@@ -341,11 +341,8 @@ class EmbeddingPostprocessor(tf.keras.layers.Layer): ...@@ -341,11 +341,8 @@ class EmbeddingPostprocessor(tf.keras.layers.Layer):
output = word_embeddings output = word_embeddings
if self.use_type_embeddings: if self.use_type_embeddings:
flat_token_type_ids = tf.reshape(token_type_ids, [-1]) flat_token_type_ids = tf.reshape(token_type_ids, [-1])
one_hot_ids = tf.one_hot( token_type_embeddings = tf.gather(self.type_embeddings,
flat_token_type_ids, flat_token_type_ids)
depth=self.token_type_vocab_size,
dtype=self.dtype)
token_type_embeddings = tf.matmul(one_hot_ids, self.type_embeddings)
token_type_embeddings = tf.reshape(token_type_embeddings, token_type_embeddings = tf.reshape(token_type_embeddings,
[batch_size, seq_length, width]) [batch_size, seq_length, width])
output += token_type_embeddings output += token_type_embeddings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment