Unverified Commit 248d5d23 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Tests: replace `torch.testing.assert_allclose` by `torch.testing.assert_close` (#29915)

* replace torch.testing.assert_allclose by torch.testing.assert_close

* missing atol rtol
parent 7c19fafe
...@@ -559,7 +559,7 @@ class FNetModelIntegrationTest(unittest.TestCase): ...@@ -559,7 +559,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
max_length=512, max_length=512,
) )
torch.testing.assert_allclose(inputs["input_ids"], torch.tensor([[4, 13, 283, 2479, 106, 8, 6, 845, 5, 168, 65, 367, 6, 845, 5, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3]])) # fmt: skip torch.testing.assert_close(inputs["input_ids"], torch.tensor([[4, 13, 283, 2479, 106, 8, 6, 845, 5, 168, 65, 367, 6, 845, 5, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3]])) # fmt: skip
inputs = {k: v.to(torch_device) for k, v in inputs.items()} inputs = {k: v.to(torch_device) for k, v in inputs.items()}
......
...@@ -189,13 +189,13 @@ class Jukebox1bModelTester(unittest.TestCase): ...@@ -189,13 +189,13 @@ class Jukebox1bModelTester(unittest.TestCase):
self.assertListEqual(metadata.numpy()[0][:10].tolist(), self.EXPECTED_Y_COND) self.assertListEqual(metadata.numpy()[0][:10].tolist(), self.EXPECTED_Y_COND)
audio_conditioning, metadata_conditioning, lyric_tokens = top_prior.get_cond(music_token_conds, metadata) audio_conditioning, metadata_conditioning, lyric_tokens = top_prior.get_cond(music_token_conds, metadata)
torch.testing.assert_allclose( torch.testing.assert_close(
audio_conditioning[0][0][:30].detach(), torch.tensor(self.EXPECTED_AUDIO_COND), atol=1e-4, rtol=1e-4 audio_conditioning[0][0][:30].detach(), torch.tensor(self.EXPECTED_AUDIO_COND), atol=1e-4, rtol=1e-4
) )
torch.testing.assert_allclose( torch.testing.assert_close(
metadata_conditioning[0][0][:30].detach(), torch.tensor(self.EXPECTED_META_COND), atol=1e-4, rtol=1e-4 metadata_conditioning[0][0][:30].detach(), torch.tensor(self.EXPECTED_META_COND), atol=1e-4, rtol=1e-4
) )
torch.testing.assert_allclose( torch.testing.assert_close(
lyric_tokens[0, :30].detach(), torch.tensor(self.EXPECTED_LYRIC_COND), atol=1e-4, rtol=1e-4 lyric_tokens[0, :30].detach(), torch.tensor(self.EXPECTED_LYRIC_COND), atol=1e-4, rtol=1e-4
) )
...@@ -213,21 +213,21 @@ class Jukebox1bModelTester(unittest.TestCase): ...@@ -213,21 +213,21 @@ class Jukebox1bModelTester(unittest.TestCase):
zs = model._sample( zs = model._sample(
zs, tokens, sample_levels=[0], save_results=False, sample_length=40 * model.priors[0].raw_to_tokens zs, tokens, sample_levels=[0], save_results=False, sample_length=40 * model.priors[0].raw_to_tokens
) )
torch.testing.assert_allclose(zs[0][0][:40], torch.tensor(self.EXPECTED_PRIMED_0)) torch.testing.assert_close(zs[0][0][:40], torch.tensor(self.EXPECTED_PRIMED_0))
upper_2 = torch.cat((zs[0], torch.zeros(1, 2048 - zs[0].shape[-1])), dim=-1).long() upper_2 = torch.cat((zs[0], torch.zeros(1, 2048 - zs[0].shape[-1])), dim=-1).long()
zs = [upper_2, model.vqvae.encode(waveform, start_level=1, bs_chunks=waveform.shape[0])[0], None] zs = [upper_2, model.vqvae.encode(waveform, start_level=1, bs_chunks=waveform.shape[0])[0], None]
zs = model._sample( zs = model._sample(
zs, tokens, sample_levels=[1], save_results=False, sample_length=40 * model.priors[1].raw_to_tokens zs, tokens, sample_levels=[1], save_results=False, sample_length=40 * model.priors[1].raw_to_tokens
) )
torch.testing.assert_allclose(zs[1][0][:40], torch.tensor(self.EXPECTED_PRIMED_1)) torch.testing.assert_close(zs[1][0][:40], torch.tensor(self.EXPECTED_PRIMED_1))
upper_1 = torch.cat((zs[1], torch.zeros(1, 2048 - zs[1].shape[-1])), dim=-1).long() upper_1 = torch.cat((zs[1], torch.zeros(1, 2048 - zs[1].shape[-1])), dim=-1).long()
zs = [upper_2, upper_1, model.vqvae.encode(waveform, start_level=0, bs_chunks=waveform.shape[0])[0]] zs = [upper_2, upper_1, model.vqvae.encode(waveform, start_level=0, bs_chunks=waveform.shape[0])[0]]
zs = model._sample( zs = model._sample(
zs, tokens, sample_levels=[2], save_results=False, sample_length=40 * model.priors[2].raw_to_tokens zs, tokens, sample_levels=[2], save_results=False, sample_length=40 * model.priors[2].raw_to_tokens
) )
torch.testing.assert_allclose(zs[2][0][:40].cpu(), torch.tensor(self.EXPECTED_PRIMED_2)) torch.testing.assert_close(zs[2][0][:40].cpu(), torch.tensor(self.EXPECTED_PRIMED_2))
@slow @slow
def test_vqvae(self): def test_vqvae(self):
...@@ -236,11 +236,11 @@ class Jukebox1bModelTester(unittest.TestCase): ...@@ -236,11 +236,11 @@ class Jukebox1bModelTester(unittest.TestCase):
x = torch.rand((1, 5120, 1)) x = torch.rand((1, 5120, 1))
with torch.no_grad(): with torch.no_grad():
zs = model.vqvae.encode(x, start_level=2, bs_chunks=x.shape[0]) zs = model.vqvae.encode(x, start_level=2, bs_chunks=x.shape[0])
torch.testing.assert_allclose(zs[0][0], torch.tensor(self.EXPECTED_VQVAE_ENCODE)) torch.testing.assert_close(zs[0][0], torch.tensor(self.EXPECTED_VQVAE_ENCODE))
with torch.no_grad(): with torch.no_grad():
x = model.vqvae.decode(zs, start_level=2, bs_chunks=x.shape[0]) x = model.vqvae.decode(zs, start_level=2, bs_chunks=x.shape[0])
torch.testing.assert_allclose(x[0, :40, 0], torch.tensor(self.EXPECTED_VQVAE_DECODE), atol=1e-4, rtol=1e-4) torch.testing.assert_close(x[0, :40, 0], torch.tensor(self.EXPECTED_VQVAE_DECODE), atol=1e-4, rtol=1e-4)
@require_torch @require_torch
...@@ -379,19 +379,19 @@ class Jukebox5bModelTester(unittest.TestCase): ...@@ -379,19 +379,19 @@ class Jukebox5bModelTester(unittest.TestCase):
model.priors[0].to(torch_device) model.priors[0].to(torch_device)
zs = [torch.zeros(1, 0, dtype=torch.long).to(torch_device) for _ in range(3)] zs = [torch.zeros(1, 0, dtype=torch.long).to(torch_device) for _ in range(3)]
zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False) zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False)
torch.testing.assert_allclose(zs[0][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2)) torch.testing.assert_close(zs[0][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2))
model.priors[0].cpu() model.priors[0].cpu()
set_seed(0) set_seed(0)
model.priors[1].to(torch_device) model.priors[1].to(torch_device)
zs = model._sample(zs, labels, [1], sample_length=60 * model.priors[1].raw_to_tokens, save_results=False) zs = model._sample(zs, labels, [1], sample_length=60 * model.priors[1].raw_to_tokens, save_results=False)
torch.testing.assert_allclose(zs[1][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_1)) torch.testing.assert_close(zs[1][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_1))
model.priors[1].cpu() model.priors[1].cpu()
set_seed(0) set_seed(0)
model.priors[2].to(torch_device) model.priors[2].to(torch_device)
zs = model._sample(zs, labels, [2], sample_length=60 * model.priors[2].raw_to_tokens, save_results=False) zs = model._sample(zs, labels, [2], sample_length=60 * model.priors[2].raw_to_tokens, save_results=False)
torch.testing.assert_allclose(zs[2][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_0)) torch.testing.assert_close(zs[2][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_0))
@slow @slow
@require_torch_accelerator @require_torch_accelerator
......
...@@ -387,12 +387,10 @@ class NllbMoeModelIntegrationTests(unittest.TestCase): ...@@ -387,12 +387,10 @@ class NllbMoeModelIntegrationTests(unittest.TestCase):
EXPECTED_DECODER_STATE = torch.Tensor([-6.0425e-02, -2.0015e-01, 6.0575e-02, -8.6366e-01, -1.1310e+00, 6.8369e-01, 7.5615e-01, 7.3555e-01, 2.3071e-01, 1.5954e+00, -7.0728e-01, -2.2647e-01, -1.3292e+00, 4.8246e-01, -6.9153e-01, -1.8199e-02, -7.3664e-01, 1.5902e-03, 1.0760e-01, 1.0298e-01, -9.3933e-01, -4.6567e-01, 8.0417e-01, 1.5243e+00, 5.5844e-01, -9.9239e-02, 1.4885e+00, 7.1527e-02, -5.2612e-01, 9.4435e-02]) EXPECTED_DECODER_STATE = torch.Tensor([-6.0425e-02, -2.0015e-01, 6.0575e-02, -8.6366e-01, -1.1310e+00, 6.8369e-01, 7.5615e-01, 7.3555e-01, 2.3071e-01, 1.5954e+00, -7.0728e-01, -2.2647e-01, -1.3292e+00, 4.8246e-01, -6.9153e-01, -1.8199e-02, -7.3664e-01, 1.5902e-03, 1.0760e-01, 1.0298e-01, -9.3933e-01, -4.6567e-01, 8.0417e-01, 1.5243e+00, 5.5844e-01, -9.9239e-02, 1.4885e+00, 7.1527e-02, -5.2612e-01, 9.4435e-02])
# fmt: on # fmt: on
torch.testing.assert_allclose( torch.testing.assert_close(
output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3 output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
) )
torch.testing.assert_allclose( torch.testing.assert_close(output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3)
output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3
)
def test_inference_logits(self): def test_inference_logits(self):
r""" r"""
...@@ -405,7 +403,7 @@ class NllbMoeModelIntegrationTests(unittest.TestCase): ...@@ -405,7 +403,7 @@ class NllbMoeModelIntegrationTests(unittest.TestCase):
output = model(**self.model_inputs) output = model(**self.model_inputs)
EXPECTED_LOGTIS = torch.Tensor([-0.3059, 0.0000, 9.3029, 0.6456, -0.9148, 1.7836, 0.6478, 0.9438, -0.5272, -0.6617, -1.2717, 0.4564, 0.1345, -0.2301, -1.0140, 1.1427, -1.5535, 0.1337, 0.2082, -0.8112, -0.3842, -0.3377, 0.1256, 0.6450, -0.0452, 0.0219, 1.4274, -0.4991, -0.2063, -0.4409,]) # fmt: skip EXPECTED_LOGTIS = torch.Tensor([-0.3059, 0.0000, 9.3029, 0.6456, -0.9148, 1.7836, 0.6478, 0.9438, -0.5272, -0.6617, -1.2717, 0.4564, 0.1345, -0.2301, -1.0140, 1.1427, -1.5535, 0.1337, 0.2082, -0.8112, -0.3842, -0.3377, 0.1256, 0.6450, -0.0452, 0.0219, 1.4274, -0.4991, -0.2063, -0.4409,]) # fmt: skip
torch.testing.assert_allclose(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3) torch.testing.assert_close(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)
@unittest.skip("This requires 300GB of RAM") @unittest.skip("This requires 300GB of RAM")
def test_large_logits(self): def test_large_logits(self):
...@@ -419,13 +417,11 @@ class NllbMoeModelIntegrationTests(unittest.TestCase): ...@@ -419,13 +417,11 @@ class NllbMoeModelIntegrationTests(unittest.TestCase):
EXPECTED_LOGTIS = torch.Tensor([ 0.3834, 0.2057, 4.5399, 0.8301, 0.4810, 0.9325, 0.9928, 0.9574, 0.5517, 0.9156, 0.2698, 0.6728, 0.7121, 0.3080, 0.4693, 0.5756, 1.0407, 0.2219, 0.3714, 0.5699, 0.5547, 0.8472, 0.3178, 0.1286, 0.1791, 0.9391, 0.5153, -0.2146, 0.1689, 0.6816]) EXPECTED_LOGTIS = torch.Tensor([ 0.3834, 0.2057, 4.5399, 0.8301, 0.4810, 0.9325, 0.9928, 0.9574, 0.5517, 0.9156, 0.2698, 0.6728, 0.7121, 0.3080, 0.4693, 0.5756, 1.0407, 0.2219, 0.3714, 0.5699, 0.5547, 0.8472, 0.3178, 0.1286, 0.1791, 0.9391, 0.5153, -0.2146, 0.1689, 0.6816])
# fmt: on # fmt: on
torch.testing.assert_allclose( torch.testing.assert_close(
output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3 output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
) )
torch.testing.assert_allclose( torch.testing.assert_close(output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3)
output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3 torch.testing.assert_close(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)
)
torch.testing.assert_allclose(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)
@unittest.skip("This requires 300GB of RAM") @unittest.skip("This requires 300GB of RAM")
def test_seq_to_seq_generation(self): def test_seq_to_seq_generation(self):
...@@ -564,10 +560,10 @@ class NllbMoeRouterTest(unittest.TestCase): ...@@ -564,10 +560,10 @@ class NllbMoeRouterTest(unittest.TestCase):
# `sampling` and `random` do not affect the mask of the top_1 router # `sampling` and `random` do not affect the mask of the top_1 router
# fmt: on # fmt: on
torch.testing.assert_allclose(router_probs_all, EXPECTED_ROUTER_ALL, 1e-4, 1e-4) torch.testing.assert_close(router_probs_all, EXPECTED_ROUTER_ALL, rtol=1e-4, atol=1e-4)
torch.testing.assert_allclose(router_probs_sp, EXPECTED_ROUTER_SP, 1e-4, 1e-4) torch.testing.assert_close(router_probs_sp, EXPECTED_ROUTER_SP, rtol=1e-4, atol=1e-4)
torch.testing.assert_allclose(router_probs, EXPECTED_ROUTER, 1e-4, 1e-4) torch.testing.assert_close(router_probs, EXPECTED_ROUTER, rtol=1e-4, atol=1e-4)
torch.testing.assert_allclose(top_1_mask_all, EXPECTED_TOP_1_ALL, 1e-4, 1e-4) torch.testing.assert_close(top_1_mask_all, EXPECTED_TOP_1_ALL, rtol=1e-4, atol=1e-4)
torch.testing.assert_allclose(top_1_mask_sp, EXPECTED_TOP_1_SP, 1e-4, 1e-4) torch.testing.assert_close(top_1_mask_sp, EXPECTED_TOP_1_SP, rtol=1e-4, atol=1e-4)
torch.testing.assert_allclose(top_1_mask, EXPECTED_TOP_1_SP, 1e-4, 1e-4) torch.testing.assert_close(top_1_mask, EXPECTED_TOP_1_SP, rtol=1e-4, atol=1e-4)
...@@ -725,7 +725,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -725,7 +725,7 @@ class SamModelIntegrationTest(unittest.TestCase):
iou_scores = outputs.iou_scores.cpu() iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 2, 3)) self.assertTrue(iou_scores.shape == (1, 2, 3))
torch.testing.assert_allclose( torch.testing.assert_close(
iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4 iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
) )
...@@ -753,7 +753,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -753,7 +753,7 @@ class SamModelIntegrationTest(unittest.TestCase):
iou_scores = outputs.iou_scores.cpu() iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 3, 3)) self.assertTrue(iou_scores.shape == (1, 3, 3))
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
def test_dummy_pipeline_generation(self): def test_dummy_pipeline_generation(self):
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device) generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
......
...@@ -1053,7 +1053,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase): ...@@ -1053,7 +1053,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu() hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
hf_logits = hf_logits[0, 0, :30] hf_logits = hf_logits[0, 0, :30]
torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3) torch.testing.assert_close(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
@unittest.skip( @unittest.skip(
"Unless we stop stripping left and right by default for all special tokens, the expected ids obtained here will not match the original ones. Wait for https://github.com/huggingface/transformers/pull/23909 to be merged" "Unless we stop stripping left and right by default for all special tokens, the expected ids obtained here will not match the original ones. Wait for https://github.com/huggingface/transformers/pull/23909 to be merged"
......
...@@ -763,7 +763,7 @@ class Umt5IntegrationTest(unittest.TestCase): ...@@ -763,7 +763,7 @@ class Umt5IntegrationTest(unittest.TestCase):
] ]
) )
# fmt: on # fmt: on
torch.testing.assert_allclose(input_ids, EXPECTED_IDS) torch.testing.assert_close(input_ids, EXPECTED_IDS)
generated_ids = model.generate(input_ids.to(torch_device)) generated_ids = model.generate(input_ids.to(torch_device))
EXPECTED_FILLING = [ EXPECTED_FILLING = [
......
...@@ -462,10 +462,10 @@ class ModelTesterMixin: ...@@ -462,10 +462,10 @@ class ModelTesterMixin:
([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475]) ([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475])
) )
init_instance = MyClass() init_instance = MyClass()
torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4) torch.testing.assert_close(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)
set_seed(0) set_seed(0)
torch.testing.assert_allclose( torch.testing.assert_close(
init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5)) init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment