"examples/trials/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "73a6990b1b4e74b7d6d837d2864b35317de5fd70"
Unverified Commit a03f7514 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix load from PT-formatted checkpoint in composite TF models (#20661)

* Fix load from PT-formatted checkpoint in composite TF models

* Leave the from_pt part as it was
parent 521da651
...@@ -2727,14 +2727,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2727,14 +2727,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return load_pytorch_checkpoint_in_tf2_model( return load_pytorch_checkpoint_in_tf2_model(
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
) )
elif safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
state_dict = safe_load_file(resolved_archive_file)
# Load from a PyTorch checkpoint
return load_pytorch_state_dict_in_tf2_model(
model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
)
# we might need to extend the variable scope for composite models # we might need to extend the variable scope for composite models
if load_weight_prefix is not None: if load_weight_prefix is not None:
...@@ -2743,6 +2735,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2743,6 +2735,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else: else:
model(model.dummy_inputs) # build the network with dummy inputs model(model.dummy_inputs) # build the network with dummy inputs
if safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
state_dict = safe_load_file(resolved_archive_file)
# Load from a PyTorch checkpoint
return load_pytorch_state_dict_in_tf2_model(
model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
)
# 'by_name' allow us to do transfer learning by skipping/adding layers # 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment