Commit 567bd18d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Explicit set the dtype in embedding layer. Otherwise, using the layer in TF1.x...

Explicit set the dtype in embedding layer. Otherwise, using the layer in TF1.x will create an int32 tensor.

PiperOrigin-RevId: 322473236
parent 36101ab4
......@@ -43,6 +43,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
self.shared_weights = self.add_weight(
"weights",
shape=[self.vocab_size, self.hidden_size],
dtype=tf.float32,
initializer=tf.random_normal_initializer(
mean=0., stddev=self.hidden_size**-0.5))
super(EmbeddingSharedWeights, self).build(input_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment