Commit 126ce652 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Do not access self.embeddings.dtype.

Soon, AutoCastVariable.dtype will refer to the variable dtype, not the compute
dtype. This change stops using AutoCastVariable.dtype.

PiperOrigin-RevId: 337670544
parent 19113a57
...@@ -77,8 +77,14 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -77,8 +77,14 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs):
flat_inputs = tf.reshape(inputs, [-1]) flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot: if self._use_one_hot:
dtype = self._compute_dtype
if not tf.dtypes.as_dtype(dtype).is_floating:
# TensorFlow 1 compatibility. In TF1, self._compute_dtype is int32
# instead of a floating-point dtype, as the dtype is inferred from the
# dtype of the inputs
dtype = tf.float32
one_hot_data = tf.one_hot( one_hot_data = tf.one_hot(
flat_inputs, depth=self._vocab_size, dtype=self.embeddings.dtype) flat_inputs, depth=self._vocab_size, dtype=dtype)
embeddings = tf.matmul(one_hot_data, self.embeddings) embeddings = tf.matmul(one_hot_data, self.embeddings)
else: else:
embeddings = tf.gather(self.embeddings, flat_inputs) embeddings = tf.gather(self.embeddings, flat_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment