Unverified commit b7fa1e3d authored by Yih-Dar, committed by GitHub

Use tiny models for get_pretrained_model in TFEncoderDecoderModelTest (#15989)



* Use tiny model for TFRembertEncoderDecoderModelTest.get_pretrained_model()
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 8feede22
@@ -509,8 +509,7 @@ class TFEncoderDecoderMixin:
         model = TFEncoderDecoderModel(encoder_decoder_config)
         model(**inputs_dict)
 
-    @slow
-    def test_real_model_save_load_from_pretrained(self):
+    def test_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
         input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
         decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
@@ -542,7 +541,10 @@ class TFEncoderDecoderMixin:
 @require_tf
 class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
     def get_pretrained_model(self):
-        return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
+        return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
+            "hf-internal-testing/tiny-random-bert",
+            "hf-internal-testing/tiny-random-bert",
+        )
 
     def get_encoder_decoder_model(self, config, decoder_config):
         encoder_model = TFBertModel(config, name="encoder")
@@ -637,7 +639,10 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
 @require_tf
 class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
     def get_pretrained_model(self):
-        return TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
+        return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
+            "hf-internal-testing/tiny-random-bert",
+            "hf-internal-testing/tiny-random-gpt2",
+        )
 
     def get_encoder_decoder_model(self, config, decoder_config):
         encoder_model = TFBertModel(config, name="encoder")
@@ -726,7 +731,10 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
 @require_tf
 class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
     def get_pretrained_model(self):
-        return TFEncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
+        return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
+            "hf-internal-testing/tiny-random-roberta",
+            "hf-internal-testing/tiny-random-roberta",
+        )
 
     def get_encoder_decoder_model(self, config, decoder_config):
         encoder_model = TFRobertaModel(config, name="encoder")
@@ -782,7 +790,10 @@ class TFRoBertaEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
 @require_tf
 class TFRembertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
     def get_pretrained_model(self):
-        return TFEncoderDecoderModel.from_encoder_decoder_pretrained("google/rembert", "google/rembert")
+        return TFEncoderDecoderModel.from_encoder_decoder_pretrained(
+            "hf-internal-testing/tiny-random-rembert",
+            "hf-internal-testing/tiny-random-rembert",
+        )
 
     def get_encoder_decoder_model(self, config, decoder_config):
         encoder_model = TFRemBertModel(config, name="encoder")
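
For reference, below is a minimal standalone sketch (not part of this diff) of what the updated get_pretrained_model helpers now do, using the same tiny checkpoints as the BERT-to-GPT2 test class and mirroring the save/load round trip in test_model_save_load_from_pretrained. It assumes network access to the hf-internal-testing checkpoints; the shapes and tolerance are illustrative.

import tempfile

import numpy as np
from transformers import TFEncoderDecoderModel

# Tiny randomly initialized checkpoints keep the download and the forward pass small,
# which is what allows the save/load test to run without the @slow marker.
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
    "hf-internal-testing/tiny-random-bert",
    "hf-internal-testing/tiny-random-gpt2",
)

# Random token ids shaped like the ids_tensor([13, 5], ...) inputs used in the test.
input_ids = np.random.randint(0, model.config.encoder.vocab_size, size=(13, 5), dtype="int32")
decoder_input_ids = np.random.randint(0, model.config.decoder.vocab_size, size=(13, 1), dtype="int32")

outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

# Save the composite model, reload it, and check the reloaded copy reproduces the logits.
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    reloaded = TFEncoderDecoderModel.from_pretrained(tmp_dir)
    reloaded_outputs = reloaded(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    np.testing.assert_allclose(
        outputs.logits.numpy(), reloaded_outputs.logits.numpy(), atol=1e-5
    )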