Unverified commit 74297d0a, authored by Younes Belkada, committed by GitHub

[Switch Transformers] Fix failing slow test (#20346)

* run slow test on GPU

* remove unnecessary device assignment

* use `torch_device` instead
parent 11f3ec72
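
For context, a minimal sketch of the pattern this commit switches to: in the `transformers` test suite, `torch_device` (from `transformers.testing_utils`) names the available accelerator (e.g. `"cuda"`) and falls back to `"cpu"`, while `require_torch_gpu` skips a test entirely on CPU-only machines. The test class, method name, and tensor shape below are illustrative assumptions, not taken from the commit:

```python
# Hypothetical sketch of the device-placement pattern used in the fix.
# `torch_device`, `require_torch`, and `require_torch_gpu` are real helpers
# from transformers.testing_utils; everything else here is made up.
import unittest

import torch
from transformers.testing_utils import require_torch, require_torch_gpu, torch_device


@require_torch
class DevicePlacementSketch(unittest.TestCase):
    @require_torch_gpu  # skipped automatically when no GPU is visible
    def test_tensor_lands_on_device(self):
        # torch_device resolves to "cuda" here, since the GPU guard passed
        x = torch.ones((2, 4), dtype=torch.long).to(torch_device)
        self.assertEqual(x.device.type, torch_device)
```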
@@ -19,7 +19,7 @@ import tempfile
 import unittest

 from transformers import SwitchTransformersConfig, is_torch_available
-from transformers.testing_utils import require_tokenizers, require_torch, slow, torch_device
+from transformers.testing_utils import require_tokenizers, require_torch, require_torch_gpu, slow, torch_device

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -1104,15 +1104,18 @@ class SwitchTransformerRouterTest(unittest.TestCase):
 @require_torch
 @require_tokenizers
 class SwitchTransformerModelIntegrationTests(unittest.TestCase):
+    @require_torch_gpu
     def test_small_logits(self):
         r"""
         Logits testing to check implementation consistency between `t5x` implementation
         and `transformers` implementation of Switch-C transformers. We only check the logits
         of the first batch.
         """
-        model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).eval()
-        input_ids = torch.ones((32, 64), dtype=torch.long)
-        decoder_input_ids = torch.ones((32, 64), dtype=torch.long)
+        model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).to(
+            torch_device
+        )
+        input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
+        decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)

         # fmt: off
         EXPECTED_MEAN_LOGITS = torch.Tensor(
@@ -1126,8 +1129,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
             ]
         ).to(torch.bfloat16)
         # fmt: on
-        hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
-
+        hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
         hf_logits = hf_logits[0, 0, :30]

         torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
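
A note on the final hunk: once the model and inputs live on `torch_device`, the returned hidden states are GPU tensors, while `EXPECTED_MEAN_LOGITS` is built on CPU, so the output must come back via `.cpu()` before the tolerance check (PyTorch refuses to compare tensors on different devices). A sketch of that comparison pattern, with placeholder values rather than the test's real reference logits, and using `torch.testing.assert_close`, which newer PyTorch prefers over the deprecated `assert_allclose`:

```python
import torch

# Placeholder expectation; the real test compares 30 bfloat16 reference logits.
EXPECTED = torch.tensor([0.1, 0.2, 0.3], dtype=torch.bfloat16)  # CPU tensor


def check_logits(hf_logits: torch.Tensor) -> None:
    # Move the (possibly CUDA) model output back to the host first, exactly
    # as the commit does, so both operands of the comparison live on CPU.
    torch.testing.assert_close(hf_logits.cpu(), EXPECTED, rtol=6e-3, atol=9e-3)
```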