Unverified Commit 0d158e38 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[fix] mobilebert had wrong path, causing slow test failure (#5205)

parent f5c2a122
...@@ -39,7 +39,7 @@ from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, p ...@@ -39,7 +39,7 @@ from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, p
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["mobilebert-uncased"] MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]
def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path): def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
......
...@@ -454,12 +454,6 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -454,12 +454,6 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs) self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = MobileBertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def _long_tensor(tok_lst): def _long_tensor(tok_lst):
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device,) return torch.tensor(tok_lst, dtype=torch.long, device=torch_device,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment