Unverified Commit 7bb6933b authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: standardize `test_model_common_attributes` for language models (#23457)

parent 4ed07528
...@@ -1013,7 +1013,7 @@ class TFModelTesterMixin: ...@@ -1013,7 +1013,7 @@ class TFModelTesterMixin:
check_hidden_states_output(config, inputs_dict, model_class) check_hidden_states_output(config, inputs_dict, model_class)
def test_model_common_attributes(self): def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
text_in_text_out_models = ( text_in_text_out_models = (
get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING) get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
+ get_values(TF_MODEL_FOR_MASKED_LM_MAPPING) + get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
...@@ -1023,24 +1023,27 @@ class TFModelTesterMixin: ...@@ -1023,24 +1023,27 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)
if model_class in text_in_text_out_models:
x = model.get_output_embeddings() legacy_text_in_text_out = model.get_lm_head() is not None
assert isinstance(x, tf.keras.layers.Layer) if model_class in text_in_text_out_models or legacy_text_in_text_out:
name = model.get_bias() out_embeddings = model.get_output_embeddings()
assert isinstance(name, dict) self.assertIsInstance(out_embeddings, tf.keras.layers.Layer)
for k, v in name.items(): bias = model.get_bias()
assert isinstance(v, tf.Variable) if bias is not None:
self.assertIsInstance(bias, dict)
for _, v in bias.items():
self.assertIsInstance(v, tf.Variable)
elif model_class in speech_in_text_out_models: elif model_class in speech_in_text_out_models:
x = model.get_output_embeddings() out_embeddings = model.get_output_embeddings()
assert isinstance(x, tf.keras.layers.Layer) self.assertIsInstance(out_embeddings, tf.keras.layers.Layer)
name = model.get_bias() bias = model.get_bias()
assert name is None self.assertIsNone(bias)
else: else:
x = model.get_output_embeddings() out_embeddings = model.get_output_embeddings()
assert x is None assert out_embeddings is None
name = model.get_bias() bias = model.get_bias()
assert name is None self.assertIsNone(bias)
def test_determinism(self): def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment