"vscode:/vscode.git/clone" did not exist on "9b0873c712348ecabdc33031e64a4975a657c715"
Commit 26497d11 authored by thomwolf's avatar thomwolf
Browse files

fix tests

parent 6a083fd4
...@@ -262,7 +262,7 @@ class TFCommonTestCases: ...@@ -262,7 +262,7 @@ class TFCommonTestCases:
# self.assertEqual(len(params_tied_2), len(params_tied)) # self.assertEqual(len(params_tied_2), len(params_tied))
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32): def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
"""Creates a random int32 tensor of the shape within the vocab size.""" """Creates a random int32 tensor of the shape within the vocab size."""
if rng is None: if rng is None:
rng = random.Random() rng = random.Random()
...@@ -275,7 +275,11 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32): ...@@ -275,7 +275,11 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
for _ in range(total_dims): for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1)) values.append(rng.randint(0, vocab_size - 1))
return tf.constant(values, shape=shape, dtype=dtype) output = tf.constant(values,
shape=shape,
dtype=dtype if dtype is not None else tf.int32)
return output
class TFModelUtilsTest(unittest.TestCase): class TFModelUtilsTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment