Unverified Commit 74297d0a authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[Switch Transformers] Fix failing slow test (#20346)

* run slow test on GPU

* remove unnecessary device assignment

* use `torch_device` instead
parent 11f3ec72
......@@ -19,7 +19,7 @@ import tempfile
import unittest
from transformers import SwitchTransformersConfig, is_torch_available
from transformers.testing_utils import require_tokenizers, require_torch, slow, torch_device
from transformers.testing_utils import require_tokenizers, require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
......@@ -1104,15 +1104,18 @@ class SwitchTransformerRouterTest(unittest.TestCase):
@require_torch
@require_tokenizers
class SwitchTransformerModelIntegrationTests(unittest.TestCase):
@require_torch_gpu
def test_small_logits(self):
r"""
Logits testing to check implementation consistency between `t5x` implementation
and `transformers` implementation of Switch-C transformers. We only check the logits
of the first batch.
"""
model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).eval()
input_ids = torch.ones((32, 64), dtype=torch.long)
decoder_input_ids = torch.ones((32, 64), dtype=torch.long)
model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).to(
torch_device
)
input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
# fmt: off
EXPECTED_MEAN_LOGITS = torch.Tensor(
......@@ -1126,8 +1129,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
]
).to(torch.bfloat16)
# fmt: on
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
hf_logits = hf_logits[0, 0, :30]
torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment