Commit 26497d11 authored by thomwolf's avatar thomwolf
Browse files

fix tests

parent 6a083fd4
......@@ -262,7 +262,7 @@ class TFCommonTestCases:
# self.assertEqual(len(params_tied_2), len(params_tied))
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
......@@ -275,7 +275,11 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return tf.constant(values, shape=shape, dtype=dtype)
output = tf.constant(values,
shape=shape,
dtype=dtype if dtype is not None else tf.int32)
return output
class TFModelUtilsTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment