Unverified Commit b7fa1e3d authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Use tiny models for get_pretrained_model in TFEncoderDecoderModelTest (#15989)



* Use tiny model for TFRembertEncoderDecoderModelTest.get_pretrained_model()
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8feede22
...@@ -509,8 +509,7 @@ class TFEncoderDecoderMixin: ...@@ -509,8 +509,7 @@ class TFEncoderDecoderMixin:
model = TFEncoderDecoderModel(encoder_decoder_config) model = TFEncoderDecoderModel(encoder_decoder_config)
model(**inputs_dict) model(**inputs_dict)
@slow def test_model_save_load_from_pretrained(self):
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model() model_2 = self.get_pretrained_model()
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size) input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size) decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
...@@ -542,7 +541,10 @@ class TFEncoderDecoderMixin: ...@@ -542,7 +541,10 @@ class TFEncoderDecoderMixin:
@require_tf @require_tf
class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self): def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-bert",
"hf-internal-testing/tiny-random-bert",
)
def get_encoder_decoder_model(self, config, decoder_config): def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFBertModel(config, name="encoder") encoder_model = TFBertModel(config, name="encoder")
...@@ -637,7 +639,10 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): ...@@ -637,7 +639,10 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
@require_tf @require_tf
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self): def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "../gpt2") return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-bert",
"hf-internal-testing/tiny-random-gpt2",
)
def get_encoder_decoder_model(self, config, decoder_config): def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFBertModel(config, name="encoder") encoder_model = TFBertModel(config, name="encoder")
...@@ -726,7 +731,10 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): ...@@ -726,7 +731,10 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
@require_tf @require_tf
class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self): def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base") return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-roberta",
"hf-internal-testing/tiny-random-roberta",
)
def get_encoder_decoder_model(self, config, decoder_config): def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFRobertaModel(config, name="encoder") encoder_model = TFRobertaModel(config, name="encoder")
...@@ -782,7 +790,10 @@ class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase) ...@@ -782,7 +790,10 @@ class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase)
@require_tf @require_tf
class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase): class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self): def get_pretrained_model(self):
return TFEncoderDecoderModel.from_encoder_decoder_pretrained("google/rembert", "google/rembert") return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-rembert",
"hf-internal-testing/tiny-random-rembert",
)
def get_encoder_decoder_model(self, config, decoder_config): def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFRemBertModel(config, name="encoder") encoder_model = TFRemBertModel(config, name="encoder")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment