Unverified Commit 84c9bf74 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

cross platform from_pretrained (#20538)



* add support for `from_pt`

* add tf_flax utility file

* Update src/transformers/modeling_tf_flax_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove flax related modifications

* add test

* remove FLAX related commits

* fixup

* remove safetensor todos

* revert deletion
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 538e5248
...@@ -47,6 +47,7 @@ from .utils import ( ...@@ -47,6 +47,7 @@ from .utils import (
SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_NAME,
TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
ModelOutput, ModelOutput,
...@@ -2392,7 +2393,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2392,7 +2393,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
save directory. save directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory. configuration JSON file named *config.json* is found in the directory.
from_pt: (`bool`, *optional*, defaults to `False`): from_pt (`bool`, *optional*, defaults to `False`):
Load the model weights from a PyTorch state_dict save file (see docstring of Load the model weights from a PyTorch state_dict save file (see docstring of
`pretrained_model_name_or_path` argument). `pretrained_model_name_or_path` argument).
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
...@@ -2531,7 +2532,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2531,7 +2532,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if pretrained_model_name_or_path is not None: if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if is_local:
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint in priority if from_pt # Load from a PyTorch checkpoint in priority if from_pt
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
...@@ -2559,7 +2560,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2559,7 +2560,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
is_sharded = True is_sharded = True
# At this stage we don't have a weight file so we will raise an error. # At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
):
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
...@@ -2630,6 +2633,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2630,6 +2633,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
) )
if resolved_archive_file is not None: if resolved_archive_file is not None:
is_sharded = True is_sharded = True
if resolved_archive_file is None and filename == WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None: if resolved_archive_file is None:
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
# message. # message.
...@@ -2646,8 +2656,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2646,8 +2656,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
) )
else: else:
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named" f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
) )
except EnvironmentError: except EnvironmentError:
...@@ -2661,7 +2671,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2661,7 +2671,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the" " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
) )
if is_local: if is_local:
logger.info(f"loading weights file {archive_file}") logger.info(f"loading weights file {archive_file}")
......
...@@ -2127,6 +2127,14 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -2127,6 +2127,14 @@ class UtilsFunctionsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, ref_model.weights): for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy()) assert np.allclose(p1.numpy(), p2.numpy())
@is_pt_tf_cross_test
def test_checkpoint_sharding_hub_from_pt(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
# the model above is the same as the model below, just a sharded pytorch version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
def test_shard_checkpoint(self): def test_shard_checkpoint(self):
# This is the model we will use, total size 340,000 bytes. # This is the model we will use, total size 340,000 bytes.
model = tf.keras.Sequential( model = tf.keras.Sequential(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment