Unverified Commit 91e1f24e authored by Joao Gante, committed by GitHub

CLI: convert sharded PT models (#17959)

* sharded conversion; add flag to control max hidden error

* better hidden name matching

* Add test: load TF from PT shards

* fix test (PT data must be local)
parent f25457b2
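
For reviewers who want to try this out: the converter is invoked through the CLI (something like `transformers-cli pt-to-tf --model-name <repo_id> --max-hidden-error 5e-4`; the flag name is taken from the argument parser added below). The core behavioral change is that hidden-layer tensors are now validated against a user-tunable threshold while actual model outputs keep the strict built-in one. A minimal, self-contained sketch of that split, with invented difference values (real values come from `find_pt_tf_differences`):

    # Sketch only: `differences` stands in for the dict returned by
    # find_pt_tf_differences; keys are output names, values are max abs diffs.
    differences = {
        "logits": 1.2e-6,           # real model output -> strict threshold
        "hidden_states_3": 2.1e-4,  # hidden layer -> relaxed, user-set threshold
    }

    MAX_ERROR = 5e-5         # built-in tolerance for user-facing outputs
    max_hidden_error = 5e-4  # what the new --max-hidden-error flag controls

    # Hidden tensors are matched by name ("better hidden name matching" above).
    output_diffs = {k: v for k, v in differences.items() if "hidden" not in k}
    hidden_diffs = {k: v for k, v in differences.items() if "hidden" in k}

    if max(output_diffs.values()) > MAX_ERROR or max(hidden_diffs.values()) > max_hidden_error:
        raise ValueError("PT/TF mismatch above tolerance")
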
@@ -34,7 +34,7 @@ from .. import (
     is_tf_available,
     is_torch_available,
 )
-from ..utils import logging
+from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
 from . import BaseTransformersCLICommand
@@ -48,7 +48,6 @@ if is_torch_available():
 MAX_ERROR = 5e-5  # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
-TF_WEIGHTS_NAME = "tf_model.h5"
 
 
 def convert_command_factory(args: Namespace):
@@ -58,7 +57,13 @@ def convert_command_factory(args: Namespace):
     Returns: ServeCommand
     """
     return PTtoTFCommand(
-        args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
+        args.model_name,
+        args.local_dir,
+        args.max_hidden_error,
+        args.new_weights,
+        args.no_pr,
+        args.push,
+        args.extra_commit_description,
     )
@@ -90,6 +95,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
             default="",
             help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
         )
+        train_parser.add_argument(
+            "--max-hidden-error",
+            type=float,
+            default=MAX_ERROR,
+            help=(
+                f"Maximum error tolerance for hidden layer outputs. Defaults to {MAX_ERROR}. If you suspect the hidden"
+                " layers outputs will be used for downstream applications, avoid increasing this tolerance."
+            ),
+        )
         train_parser.add_argument(
             "--new-weights",
             action="store_true",
@@ -112,14 +126,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         train_parser.set_defaults(func=convert_command_factory)
 
     @staticmethod
-    def find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input):
+    def find_pt_tf_differences(pt_outputs, tf_outputs):
         """
-        Compares the TensorFlow and PyTorch models, given their inputs, returning a dictionary with all tensor
-        differences.
+        Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.
         """
-        pt_outputs = pt_model(**pt_input, output_hidden_states=True)
-        tf_outputs = tf_model(**tf_input, output_hidden_states=True)
-
         # 1. All output attributes must be the same
         pt_out_attrs = set(pt_outputs.keys())
         tf_out_attrs = set(tf_outputs.keys())
@@ -158,6 +168,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         self,
         model_name: str,
         local_dir: str,
+        max_hidden_error: float,
         new_weights: bool,
         no_pr: bool,
         push: bool,
@@ -167,6 +178,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         self._logger = logging.get_logger("transformers-cli/pt_to_tf")
         self._model_name = model_name
         self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
+        self._max_hidden_error = max_hidden_error
         self._new_weights = new_weights
         self._no_pr = no_pr
         self._push = push
@@ -260,34 +272,49 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         pt_model = pt_class.from_pretrained(self._local_dir)
-        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
         pt_input, tf_input = self.get_inputs(pt_model, config)
+        pt_outputs = pt_model(**pt_input, output_hidden_states=True)
+        del pt_model  # will no longer be used, and may have a large memory footprint
+
+        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
+        tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True)
 
         # Confirms that cross loading PT weights into TF worked.
-        crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
-        max_crossload_diff = max(crossload_differences.values())
-        if max_crossload_diff > MAX_ERROR:
+        crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
+        output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
+        hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
+        max_crossload_output_diff = max(output_differences.values())
+        max_crossload_hidden_diff = max(hidden_differences.values())
+        if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error:
             raise ValueError(
-                "The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
-                f" maximum tensor differences above the error threshold ({MAX_ERROR}):\n"
-                + "\n".join(
-                    [f"{key}: {value:.3e}" for key, value in crossload_differences.items() if value > MAX_ERROR]
-                )
+                "The cross-loaded TensorFlow model has different outputs, something went wrong!\n"
+                + f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
+                + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
+                + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
+                + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
             )
 
         # Save the weights in a TF format (if needed) and confirms that the results are still good
-        tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
-        if not os.path.exists(tf_weights_path) or self._new_weights:
-            tf_from_pt_model.save_weights(tf_weights_path)
+        tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME)
+        tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME)
+        if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights:
+            tf_from_pt_model.save_pretrained(self._local_dir)
         del tf_from_pt_model  # will no longer be used, and may have a large memory footprint
+
         tf_model = tf_class.from_pretrained(self._local_dir)
-        conversion_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input)
-        max_conversion_diff = max(conversion_differences.values())
-        if max_conversion_diff > MAX_ERROR:
+        tf_outputs = tf_model(**tf_input, output_hidden_states=True)
+
+        conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
+        output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
+        hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
+        max_conversion_output_diff = max(output_differences.values())
+        max_conversion_hidden_diff = max(hidden_differences.values())
+        if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error:
             raise ValueError(
-                "The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
-                f" tensor differences above the error threshold ({MAX_ERROR}):\n"
-                + "\n".join(
-                    [f"{key}: {value:.3e}" for key, value in conversion_differences.items() if value > MAX_ERROR]
-                )
+                "The converted TensorFlow model has different outputs, something went wrong!\n"
+                + f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
+                + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
+                + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
+                + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
             )
 
         commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
@@ -300,16 +327,31 @@ class PTtoTFCommand(BaseTransformersCLICommand):
             self._logger.warn("Uploading the weights into a new PR...")
             commit_descrition = (
                 "Model converted by the [`transformers`' `pt_to_tf`"
-                " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
-                "\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
-                f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
-                f" difference={max_conversion_diff:.3e}."
+                " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). "
+                "All converted model outputs and hidden layers were validated against its Pytorch counterpart.\n\n"
+                f"Maximum crossload output difference={max_crossload_output_diff:.3e}; "
+                f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n"
+                f"Maximum conversion output difference={max_conversion_output_diff:.3e}; "
+                f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n"
             )
             if self._extra_commit_description:
                 commit_descrition += "\n\n" + self._extra_commit_description
+
+            # sharded model -> adds all related files (index and .h5 shards)
+            if os.path.exists(tf_weights_index_path):
+                operations = [
+                    CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path)
+                ]
+                for shard_path in tf.io.gfile.glob(self._local_dir + "/tf_model-*.h5"):
+                    operations += [
+                        CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path)
+                    ]
+            else:
+                operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)]
+
             hub_pr_url = create_commit(
                 repo_id=self._model_name,
-                operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
+                operations=operations,
                 commit_message=commit_message,
                 commit_description=commit_descrition,
                 repo_type="model",
...
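
A note on the upload branch above: sharded TF checkpoints written by `save_pretrained` consist of an index file (`tf_model.h5.index.json`, i.e. `TF2_WEIGHTS_INDEX_NAME`) plus shard files named like `tf_model-00001-of-00002.h5`, which is what the `tf_model-*.h5` glob collects; a single-file checkpoint falls back to uploading `tf_model.h5` alone.
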
@@ -117,10 +117,17 @@ def load_pytorch_checkpoint_in_tf2_model(
         )
         raise
 
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info(f"Loading PyTorch weights from {pt_path}")
+    # Treats a single file as a collection of shards with 1 shard.
+    if isinstance(pytorch_checkpoint_path, str):
+        pytorch_checkpoint_path = [pytorch_checkpoint_path]
+
+    # Loads all shards into a single state dictionary
+    pt_state_dict = {}
+    for path in pytorch_checkpoint_path:
+        pt_path = os.path.abspath(path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
+        pt_state_dict.update(torch.load(pt_path, map_location="cpu"))
 
-    pt_state_dict = torch.load(pt_path, map_location="cpu")
     logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
 
     return load_pytorch_weights_in_tf2_model(
...
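
After this change, `load_pytorch_checkpoint_in_tf2_model` accepts either one checkpoint path or a list of shard paths. A minimal sketch of the new calling convention; the model class and shard file names here are illustrative, not taken from the commit:

    from transformers import BertConfig, TFBertModel
    from transformers.modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

    # Randomly initialized TF model to receive the PyTorch weights.
    tf_model = TFBertModel(BertConfig())

    # A plain string still works (treated as a list with a single shard), and a
    # list of shards is merged into one state dict before the weights are loaded.
    shards = [
        "checkpoint/pytorch_model-00001-of-00002.bin",
        "checkpoint/pytorch_model-00002-of-00002.bin",
    ]
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, shards)
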
@@ -50,6 +50,7 @@ from .utils import (
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     TF2_WEIGHTS_INDEX_NAME,
     TF2_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     EntryNotFoundError,
     ModelOutput,
@@ -2157,11 +2158,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                     # Load from a PyTorch checkpoint in priority if from_pt
                     archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                    # Load from a sharded PyTorch checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                    is_sharded = True
                 elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                     # Load from a TF 2.0 checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                 elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
-                    # Load from a sharded PyTorch checkpoint
+                    # Load from a sharded TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
                     is_sharded = True
                 # At this stage we don't have a weight file so we will raise an error.
...
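
The new `WEIGHTS_INDEX_NAME` branch means a local directory containing a sharded PyTorch checkpoint can now be cross-loaded directly. A rough way to exercise it end to end, assuming a toy config small enough to force sharding (the directory name and shard size are illustrative):

    from transformers import BertConfig, BertModel, TFBertModel

    # Save a sharded PyTorch checkpoint: this writes pytorch_model.bin.index.json
    # (WEIGHTS_INDEX_NAME) plus the individual .bin shards.
    config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)
    BertModel(config).save_pretrained("tiny-bert-sharded", max_shard_size="50kB")

    # from_pt=True now finds the index file and sets is_sharded=True internally.
    tf_model = TFBertModel.from_pretrained("tiny-bert-sharded", from_pt=True)
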
@@ -27,7 +27,7 @@ from typing import List, Tuple
 from datasets import Dataset
-from huggingface_hub import HfFolder, delete_repo, set_access_token
+from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
 from requests.exceptions import HTTPError
 
 from transformers import is_tf_available, is_torch_available
 from transformers.configuration_utils import PretrainedConfig
@@ -1966,6 +1966,16 @@ class UtilsFunctionsTest(unittest.TestCase):
             for p1, p2 in zip(model.weights, ref_model.weights):
                 assert np.allclose(p1.numpy(), p2.numpy())
 
+    @is_pt_tf_cross_test
+    def test_checkpoint_sharding_local_from_pt(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            _ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
+            model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
+            # the model above is the same as the model below, just a sharded pytorch version.
+            ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+            for p1, p2 in zip(model.weights, ref_model.weights):
+                assert np.allclose(p1.numpy(), p2.numpy())
+
     def test_shard_checkpoint(self):
         # This is the model we will use, total size 340,000 bytes.
         model = tf.keras.Sequential(
...
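
A note on running the new test: `@is_pt_tf_cross_test` gates it behind both frameworks being installed and the PT/TF cross tests being enabled (via the `RUN_PT_TF_CROSS_TESTS=1` environment variable in the transformers test suite), e.g. `RUN_PT_TF_CROSS_TESTS=1 python -m pytest tests/ -k test_checkpoint_sharding_local_from_pt`. The `Repository` clone is what makes the PyTorch shards local, matching the "fix test (PT data must be local)" note in the commit message: the new branch only resolves sharded PT checkpoints from a local directory.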