Unverified Commit a2392415, authored by code-review-doctor, committed by GitHub

Fix tests misusing assertTrue for comparisons (#16771)

* Fix issue avoid-misusing-assert-true found at https://codereview.doctor



* fix tests

* fix tf
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent d3bd9ac7
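The pattern fixed here: unittest's assertTrue accepts a single expression plus an optional failure message, so a call like self.assertTrue(a, b) never compares a to b; it passes whenever a is truthy. assertEqual performs the intended comparison. A minimal, hypothetical example of the pitfall and the fix (not code from this repository):

import unittest


class ShapeAssertions(unittest.TestCase):
    """Hypothetical test illustrating the assertTrue pitfall fixed in this commit."""

    def test_shape_comparison(self):
        shape = (1, 4, 8)

        # Misuse: assertTrue(expr, msg) treats its second argument as the
        # failure message, not as an expected value. This call therefore
        # passes for any non-empty tuple, even though (1, 4, 8) != (1, 8, 4).
        self.assertTrue(shape, (1, 8, 4))

        # Fix: assertEqual actually compares the two values and fails on a
        # mismatch, which is what the updated tests below now do.
        self.assertEqual(shape, (1, 4, 8))


if __name__ == "__main__":
    unittest.main()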
@@ -299,6 +299,10 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
         if output_word_offsets:
             word_offsets = self._get_word_offsets(char_offsets, self.replace_word_delimiter_char)

+        # don't output chars if not set to True
+        if not output_char_offsets:
+            char_offsets = None
+
         # join to string
         join_char = " " if spaces_between_special_tokens else ""
         string = join_char.join(processed_chars).strip()
...
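The added guard mirrors how word offsets are already handled: an offsets field only reaches the decoded output when the caller requests it, which is also what the tokenizer tests further down check via len(outputs.keys()). A minimal sketch of that pattern using simplified, hypothetical names (not the actual tokenizer code):

def decode_with_offsets(text, char_offsets, word_offsets,
                        output_char_offsets=False, output_word_offsets=False):
    """Hypothetical helper: keep only the offsets the caller asked for."""
    # Drop offsets that were not requested.
    if not output_char_offsets:
        char_offsets = None
    if not output_word_offsets:
        word_offsets = None

    # Only requested offsets become keys of the result.
    output = {"text": text}
    if char_offsets is not None:
        output["char_offsets"] = char_offsets
    if word_offsets is not None:
        output["word_offsets"] = word_offsets
    return output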
@@ -416,11 +416,11 @@ class LongformerModelIntegrationTest(unittest.TestCase):
     def test_pad_and_transpose_last_two_dims(self):
         hidden_states = self._get_hidden_states()
-        self.assertTrue(hidden_states.shape, (1, 8, 4))
+        self.assertEqual(hidden_states.shape, (1, 4, 8))
         padding = (0, 0, 0, 1)
         padded_hidden_states = LongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, padding)
-        self.assertTrue(padded_hidden_states.shape, (1, 8, 5))
+        self.assertEqual(padded_hidden_states.shape, (1, 8, 5))
         expected_added_dim = torch.zeros((5,), device=torch_device, dtype=torch.float32)
         self.assertTrue(torch.allclose(expected_added_dim, padded_hidden_states[0, -1, :], atol=1e-6))
@@ -445,7 +445,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, atol=1e-3))
         self.assertTrue(torch.allclose(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, atol=1e-3))
-        self.assertTrue(chunked_hidden_states.shape, (1, 3, 4, 4))
+        self.assertEqual(chunked_hidden_states.shape, (1, 3, 4, 4))

     def test_mask_invalid_locations(self):
         hidden_states = self._get_hidden_states()
@@ -493,7 +493,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
             is_global_attn=is_global_attn,
         )[0]
-        self.assertTrue(output_hidden_states.shape, (1, 4, 8))
+        self.assertEqual(output_hidden_states.shape, (1, 4, 8))
         self.assertTrue(
             torch.allclose(
                 output_hidden_states[0, 1],
@@ -531,7 +531,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
             is_global_attn=is_global_attn,
         )[0]
-        self.assertTrue(output_hidden_states.shape, (2, 4, 8))
+        self.assertEqual(output_hidden_states.shape, (2, 4, 8))
         self.assertTrue(
             torch.allclose(
...
@@ -413,7 +413,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
     def test_pad_and_transpose_last_two_dims(self):
         hidden_states = self._get_hidden_states()
-        self.assertTrue(shape_list(hidden_states), [1, 8, 4])
+        self.assertEqual(shape_list(hidden_states), [1, 4, 8])

         # pad along seq length dim
         paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
@@ -486,7 +486,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
             [0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
         )
-        self.assertTrue(output_hidden_states.shape, (1, 4, 8))
+        self.assertEqual(output_hidden_states.shape, (1, 4, 8))
         tf.debugging.assert_near(output_hidden_states[0, 1], expected_slice, rtol=1e-3)

     def test_layer_global_attn(self):
@@ -523,7 +523,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
             ]
         )[0]
-        self.assertTrue(output_hidden_states.shape, (2, 4, 8))
+        self.assertEqual(output_hidden_states.shape, (2, 4, 8))
         expected_slice_0 = tf.convert_to_tensor(
             [-0.06508, -0.039306, 0.030934, -0.03417, -0.00656, -0.01553, -0.02088, -0.04938], dtype=tf.dtypes.float32
         )
...
@@ -185,7 +185,7 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
         expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
         self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
-        self.assertTrue(input_9.shape[:2], (batch_size, expected_mult_pad_length))
+        self.assertEqual(input_9.shape[:2], (batch_size, expected_mult_pad_length))

         if feature_size > 1:
             self.assertTrue(input_9.shape[2] == feature_size)
...
@@ -809,7 +809,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True)
         preds = trainer.predict(trainer.eval_dataset).predictions
         x = trainer.eval_dataset.x
-        self.assertTrue(len(preds), 2)
+        self.assertEqual(len(preds), 2)
         self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
         self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
@@ -819,7 +819,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         preds = outputs.predictions
         labels = outputs.label_ids
         x = trainer.eval_dataset.x
-        self.assertTrue(len(preds), 2)
+        self.assertEqual(len(preds), 2)
         self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
         self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
         self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
...
@@ -97,9 +97,9 @@ class TrainerUtilsTest(unittest.TestCase):
         gatherer.add_arrays([predictions[indices], [predictions[indices], predictions[indices]]])
         result = gatherer.finalize()
         self.assertTrue(isinstance(result, list))
-        self.assertTrue(len(result), 2)
+        self.assertEqual(len(result), 2)
         self.assertTrue(isinstance(result[1], list))
-        self.assertTrue(len(result[1]), 2)
+        self.assertEqual(len(result[1]), 2)
         self.assertTrue(np.array_equal(result[0], predictions))
         self.assertTrue(np.array_equal(result[1][0], predictions))
         self.assertTrue(np.array_equal(result[1][1], predictions))
...
@@ -386,7 +386,7 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
         # make sure that full vectors are sampled and not values of vectors
         # => this means that `unique()` yields a single value for `hidden_size` dim
-        self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+        self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))

     def test_sample_negatives_with_attn_mask(self):
         batch_size = 2
@@ -428,7 +428,7 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
         # make sure that full vectors are sampled and not just slices of vectors
         # => this means that `unique()` yields a single value for `hidden_size` dim
-        self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+        self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))

 @require_flax
...
...@@ -1061,7 +1061,7 @@ class Wav2Vec2UtilsTest(unittest.TestCase): ...@@ -1061,7 +1061,7 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
self.assertTrue(((negative - features) == 0).sum() == 0.0) self.assertTrue(((negative - features) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1)) self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_mask(self): def test_sample_negatives_with_mask(self):
batch_size = 2 batch_size = 2
...@@ -1098,7 +1098,7 @@ class Wav2Vec2UtilsTest(unittest.TestCase): ...@@ -1098,7 +1098,7 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
self.assertTrue(((negative - features) == 0).sum() == 0.0) self.assertTrue(((negative - features) == 0).sum() == 0.0)
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1)) self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_torch @require_torch
......
@@ -202,7 +202,7 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
         input_values_5 = tokenizer(speech_inputs, padding="max_length", max_length=1600).input_values
         self.assertTrue(_input_values_are_equal(input_values_1, input_values_4))
-        self.assertTrue(input_values_5.shape, (3, 1600))
+        self.assertEqual(input_values_5.shape, (3, 1600))
         # padding should be 0.0
         self.assertTrue(abs(sum(np.asarray(input_values_5[0])[800:1200])) < 1e-3)
@@ -213,8 +213,8 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
         ).input_values
         self.assertTrue(_input_values_are_equal(input_values_1, input_values_6))
-        self.assertTrue(input_values_7.shape, (3, 1500))
-        self.assertTrue(input_values_8.shape, (3, 2500))
+        self.assertEqual(input_values_7.shape, (3, 1500))
+        self.assertEqual(input_values_8.shape, (3, 2500))
         # padding should be 0.0
         self.assertTrue(abs(sum(np.asarray(input_values_7[0])[800:])) < 1e-3)
         self.assertTrue(abs(sum(np.asarray(input_values_7[1])[1000:])) < 1e-3)
@@ -489,21 +489,21 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         outputs_char = tokenizer.decode(sample_ids, output_char_offsets=True)
         # check Wav2Vec2CTCTokenizerOutput keys for char
-        self.assertTrue(len(outputs_char.keys()), 2)
+        self.assertEqual(len(outputs_char.keys()), 2)
         self.assertTrue("text" in outputs_char)
         self.assertTrue("char_offsets" in outputs_char)
         self.assertTrue(isinstance(outputs_char, Wav2Vec2CTCTokenizerOutput))

         outputs_word = tokenizer.decode(sample_ids, output_word_offsets=True)
         # check Wav2Vec2CTCTokenizerOutput keys for word
-        self.assertTrue(len(outputs_word.keys()), 2)
+        self.assertEqual(len(outputs_word.keys()), 2)
         self.assertTrue("text" in outputs_word)
         self.assertTrue("word_offsets" in outputs_word)
         self.assertTrue(isinstance(outputs_word, Wav2Vec2CTCTokenizerOutput))

         outputs = tokenizer.decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
         # check Wav2Vec2CTCTokenizerOutput keys for both
-        self.assertTrue(len(outputs.keys()), 3)
+        self.assertEqual(len(outputs.keys()), 3)
         self.assertTrue("text" in outputs)
         self.assertTrue("char_offsets" in outputs)
         self.assertTrue("word_offsets" in outputs)
...
@@ -265,7 +265,7 @@ class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         outputs = tokenizer.decode(sample_ids, output_char_offsets=True, filter_word_delimiter_token=False)
         # check Wav2Vec2CTCTokenizerOutput keys for char
-        self.assertTrue(len(outputs.keys()), 2)
+        self.assertEqual(len(outputs.keys()), 2)
         self.assertTrue("text" in outputs)
         self.assertTrue("char_offsets" in outputs)
         self.assertTrue(isinstance(outputs, Wav2Vec2PhonemeCTCTokenizerOutput))
...
@@ -368,7 +368,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
         outputs = processor.decode(logits, output_word_offsets=True)
         # check Wav2Vec2CTCTokenizerOutput keys for word
-        self.assertTrue(len(outputs.keys()), 2)
+        self.assertEqual(len(outputs.keys()), 4)
         self.assertTrue("text" in outputs)
         self.assertTrue("word_offsets" in outputs)
         self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
@@ -385,7 +385,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
         outputs = processor.batch_decode(logits, output_word_offsets=True)
         # check Wav2Vec2CTCTokenizerOutput keys for word
-        self.assertTrue(len(outputs.keys()), 2)
+        self.assertEqual(len(outputs.keys()), 4)
         self.assertTrue("text" in outputs)
         self.assertTrue("word_offsets" in outputs)
         self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
...