Unverified Commit 71d47f0a authored by Matt's avatar Matt Committed by GitHub
Browse files

More TF fixes (#28081)

* More build_in_name_scope()

* Make sure we set the save spec now we don't do it with dummies anymore

* make fixup
parent 0695b242
......@@ -1147,6 +1147,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
self.config = config
self.name_or_path = config.name_or_path
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
self._set_save_spec(self.input_signature)
def get_config(self):
return self.config.to_dict()
......
......@@ -211,7 +211,7 @@ class TFAutoModelTest(unittest.TestCase):
config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = TFAutoModel.from_config(config)
model.build()
model.build_in_name_scope()
self.assertIsInstance(model, TFFunnelBaseModel)
......@@ -249,7 +249,7 @@ class TFAutoModelTest(unittest.TestCase):
config = NewModelConfig(**tiny_config.to_dict())
model = auto_class.from_config(config)
model.build()
model.build_in_name_scope()
self.assertIsInstance(model, TFNewModel)
......
......@@ -445,7 +445,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
continue
model = model_class(config)
model.build()
model.build_in_name_scope()
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
......
......@@ -312,7 +312,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
model = model_class(config)
model.build()
model.build_in_name_scope()
embeds = model.get_encoder().embed_positions.get_weights()[0]
sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
......
......@@ -217,7 +217,7 @@ class TFCoreModelTesterMixin:
for model_class in self.all_model_classes[:2]:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
model.build()
model.build_in_name_scope()
num_out = len(model(class_inputs_dict))
for key in list(class_inputs_dict.keys()):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment