Unverified Commit d07c771d authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Torchscript test for ConvBERT (#13352)

* Torchscript test for ConvBERT

* Apply suggestions from code review
parent 680733a7
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch ConvBERT model. """ """ Testing suite for the PyTorch ConvBERT model. """
import os
import tempfile
import unittest import unittest
from tests.test_modeling_common import floats_tensor from tests.test_modeling_common import floats_tensor
from transformers import ConvBertConfig, is_torch_available from transformers import ConvBertConfig, is_torch_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
...@@ -416,6 +416,29 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -416,6 +416,29 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length], [self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length],
) )
@slow
@require_torch_gpu
def test_torchscript_device_change(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# ConvBertForMultipleChoice behaves incorrectly in JIT environments.
if model_class == ConvBertForMultipleChoice:
return
config.torchscript = True
model = model_class(config=config)
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
traced_model = torch.jit.trace(
model, (inputs_dict["input_ids"].to("cpu"), inputs_dict["attention_mask"].to("cpu"))
)
with tempfile.TemporaryDirectory() as tmp:
torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt"))
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
@require_torch @require_torch
class ConvBertModelIntegrationTest(unittest.TestCase): class ConvBertModelIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment