Unverified Commit 6c24443f authored by Sylvain Gugger, committed by GitHub

Safetensors tf (#19900)



* Wip

* Add safetensors support for TensorFlow

* First tests

* Add final test for now

* Retrigger CI like this

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
parent e4132952
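Before the diff, a minimal sketch of the round-trip this commit enables (the model name is borrowed from the tests at the bottom; the local directory path is arbitrary, and the `safetensors` package must be installed since every new code path is gated on `is_safetensors_available()`):

    import numpy as np
    from transformers import TFBertModel

    model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    # New flag: writes model.safetensors instead of tf_model.h5.
    model.save_pretrained("./tiny-bert-tf", safe_serialization=True)

    # from_pretrained now resolves model.safetensors automatically when safetensors
    # is installed, whether the archive was produced from TF or PyTorch weights.
    reloaded = TFBertModel.from_pretrained("./tiny-bert-tf")
    for w1, w2 in zip(model.weights, reloaded.weights):
        assert np.allclose(w1.numpy(), w2.numpy())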
src/transformers/modeling_tf_pytorch_utils.py

@@ -21,7 +21,7 @@ import re
 import numpy
-from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze
+from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
 from .utils import transpose as transpose_func

@@ -273,7 +273,7 @@ def load_pytorch_state_dict_in_tf2_model(
         array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
-        tf_loaded_numel += array.size
+        tf_loaded_numel += tensor_size(array)
         weight_value_tuples.append((symbolic_weight, array))
         all_pytorch_weights.discard(name)
src/transformers/modeling_tf_utils.py

@@ -47,6 +47,8 @@ from .generation_tf_utils import TFGenerationMixin
 from .tf_utils import shape_list
 from .utils import (
     DUMMY_INPUTS,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     TF2_WEIGHTS_INDEX_NAME,
     TF2_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,

@@ -59,12 +61,18 @@ from .utils import (
     has_file,
     is_offline_mode,
     is_remote_url,
+    is_safetensors_available,
     logging,
     requires_backends,
     working_or_temp_dir,
 )

+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.tensorflow import load_file as safe_load_file
+    from safetensors.tensorflow import save_file as safe_save_file
+
 if TYPE_CHECKING:
     from . import PreTrainedTokenizerBase

@@ -612,6 +620,14 @@ def dtype_byte_size(dtype):
     return bit_size // 8


+def format_weight_name(name, _prefix=None):
+    if "model." not in name and len(name.split("/")) > 1:
+        name = "/".join(name.split("/")[1:])
+    if _prefix is not None:
+        name = _prefix + "/" + name
+    return name
+
+
 def tf_shard_checkpoint(weights, max_shard_size="10GB"):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
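The `format_weight_name` helper added above normalizes Keras variable names before they are used as safetensors keys: the outermost (model-level) name scope is dropped, and an optional prefix for composite models is prepended. A quick illustration with hypothetical names (the ":0" suffix Keras appends is kept as-is by this first version):

    format_weight_name("tf_bert_model/bert/pooler/dense/kernel:0")
    # -> "bert/pooler/dense/kernel:0"
    format_weight_name("bert/pooler/dense/kernel:0", _prefix="decoder")
    # -> "decoder/pooler/dense/kernel:0"  (the "bert" scope is dropped, then the prefix is added)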
@@ -849,6 +865,17 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
         Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
         mismatched layers.
     """
+    if resolved_archive_file.endswith(".safetensors"):
+        load_function = load_tf_weights_from_safetensors
+    else:
+        load_function = load_tf_weights_from_h5
+
+    return load_function(
+        model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix
+    )
+
+
+def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     missing_layers = []
     unexpected_layers = []
     mismatched_layers = []

@@ -952,6 +979,47 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
     return missing_layers, unexpected_layers, mismatched_layers
+
+
+def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+    # Read the safetensors file
+    state_dict = safe_load_file(resolved_archive_file)
+
+    weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
+    loaded_weight_names = list(state_dict.keys())
+
+    # Find the missing layers from the high level list of layers
+    missing_layers = list(set(weight_names) - set(loaded_weight_names))
+    # Find the unexpected layers from the high level list of layers
+    unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
+
+    mismatched_layers = []
+    weight_value_tuples = []
+    for weight in model.weights:
+        weight_name = format_weight_name(weight.name, _prefix=_prefix)
+        if weight_name in state_dict:
+            weight_value = state_dict[weight_name]
+            # Check if the shape of the current weight and the one from the safetensors file are different
+            if K.int_shape(weight) != weight_value.shape:
+                # If yes we reshape the weight from the safetensors file according to the current weight
+                # If the two shapes are not compatible we raise an issue
+                try:
+                    weight_value = tf.reshape(weight_value, K.int_shape(weight))
+                except ValueError as e:
+                    if ignore_mismatched_sizes:
+                        mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
+                        continue
+                    else:
+                        raise e
+
+            weight_value_tuples.append((weight, weight_value))
+
+    # Load all the weights
+    K.batch_set_value(weight_value_tuples)
+
+    return missing_layers, unexpected_layers, mismatched_layers


 def init_copy_embeddings(old_embeddings, new_num_tokens):
     r"""
     This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
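For debugging, the same `safe_open` helper imported above can be used to inspect exactly what the loader will see; a short sketch (the file path is hypothetical):

    from safetensors import safe_open

    with safe_open("./tiny-bert-tf/model.safetensors", framework="tf") as f:
        print(f.metadata())        # e.g. {"format": "tf"}, checked during from_pretrained below
        for name in f.keys():      # keys follow the format_weight_name convention
            print(name, f.get_tensor(name).shape)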
@@ -2118,6 +2186,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         signatures=None,
         max_shard_size: Union[int, str] = "10GB",
         create_pr: bool = False,
+        safe_serialization: bool = False,
         **kwargs
     ):
         """

@@ -2152,6 +2221,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             create_pr (`bool`, *optional*, defaults to `False`):
                 Whether or not to create a PR with the uploaded files or directly commit.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional TensorFlow way (with `h5`).
             kwargs:
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

@@ -2186,7 +2257,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         self.config.save_pretrained(save_directory)

         # If we save using the predefined names, we can load using `from_pretrained`
-        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
+        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
+        output_model_file = os.path.join(save_directory, weights_name)

         shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
@@ -2195,15 +2267,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             full_filename = os.path.join(save_directory, filename)

             # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
             # in distributed settings to avoid race conditions.
+            weights_no_suffix = weights_name.replace(".h5", "").replace(".safetensors", "")
             if (
-                filename.startswith(TF2_WEIGHTS_NAME[:-4])
+                filename.startswith(weights_no_suffix)
                 and os.path.isfile(full_filename)
                 and filename not in shards.keys()
             ):
                 os.remove(full_filename)

         if index is None:
-            self.save_weights(output_model_file)
+            if safe_serialization:
+                state_dict = {format_weight_name(w.name): w.value() for w in self.weights}
+                safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
+            else:
+                self.save_weights(output_model_file)
             logger.info(f"Model weights saved in {output_model_file}")
         else:
             save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
@@ -2427,6 +2504,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 # Load from a sharded PyTorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
                 is_sharded = True
+            elif is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+            ):
+                # Load from a safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+            elif is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                is_sharded = True
+                raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                 # Load from a TF 2.0 checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
@@ -2457,7 +2546,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 resolved_archive_file = download_url(pretrained_model_name_or_path)
             else:
                 # set correct filename
-                filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
+                if from_pt:
+                    filename = WEIGHTS_NAME
+                elif is_safetensors_available():
+                    filename = SAFE_WEIGHTS_NAME
+                else:
+                    filename = TF2_WEIGHTS_NAME

                 try:
                     # Load from URL or cache if already cached
) )
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not. # result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
raise NotImplementedError(
"Support for sharded checkpoints using safetensors is coming soon!"
)
else:
# This repo has no safetensors file of any kind, we switch to TensorFlow.
filename = TF2_WEIGHTS_NAME
resolved_archive_file = cached_file(
pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
)
if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME: if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file( resolved_archive_file = cached_file(
...@@ -2521,6 +2631,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2521,6 +2631,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if is_local: if is_local:
logger.info(f"loading weights file {archive_file}") logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file resolved_archive_file = archive_file
filename = resolved_archive_file.split(os.path.sep)[-1]
else: else:
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else: else:
...@@ -2543,6 +2654,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2543,6 +2654,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
_commit_hash=commit_hash, _commit_hash=commit_hash,
) )
+        safetensors_from_pt = False
+        if filename == SAFE_WEIGHTS_NAME:
+            with safe_open(resolved_archive_file, framework="tf") as f:
+                safetensors_metadata = f.metadata()
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
+                raise OSError(
+                    f"The safetensors archive passed at {resolved_archive_file} does not contain valid metadata."
+                    " Make sure you save your model with the `save_pretrained` method."
+                )
+            safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+
         config.name_or_path = pretrained_model_name_or_path

         # composed models, *e.g.* TFRag, require special treatment when it comes to loading
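The "format" metadata key checked above is the contract that makes cross-framework loading safe: `save_pretrained` writes it ("tf" here, "pt" on the PyTorch side), and `from_pretrained` refuses archives without it. A minimal sketch of producing a conforming archive by hand (toy tensor name and path, for illustration only):

    import tensorflow as tf
    from safetensors.tensorflow import save_file

    save_file({"dense/kernel:0": tf.zeros((2, 2))}, "toy.safetensors", metadata={"format": "tf"})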
@@ -2560,6 +2682,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 return load_pytorch_checkpoint_in_tf2_model(
                     model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
                 )
+            elif safetensors_from_pt:
+                from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
+
+                state_dict = safe_load_file(resolved_archive_file)
+                # Load from a PyTorch checkpoint
+                return load_pytorch_state_dict_in_tf2_model(
+                    model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
+                )

             # we might need to extend the variable scope for composite models
             if load_weight_prefix is not None:
src/transformers/utils/__init__.py

@@ -49,6 +49,7 @@ from .generic import (
     is_torch_tensor,
     reshape,
     squeeze,
+    tensor_size,
     to_numpy,
     to_py_obj,
     transpose,
src/transformers/utils/generic.py

@@ -445,3 +445,19 @@ def expand_dims(array, axis):
         return jnp.expand_dims(array, axis=axis)
     else:
         raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
+
+
+def tensor_size(array):
+    """
+    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.
+    """
+    if is_numpy_array(array):
+        return np.size(array)
+    elif is_torch_tensor(array):
+        return array.numel()
+    elif is_tf_tensor(array):
+        return tf.size(array)
+    elif is_jax_tensor(array):
+        return array.size
+    else:
+        raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
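A quick sanity check of the new helper (NumPy branch shown; the torch/TF/JAX branches are analogous, though note the TensorFlow branch returns a scalar `tf.Tensor` rather than a Python int):

    import numpy as np
    from transformers.utils import tensor_size

    assert tensor_size(np.zeros((3, 4))) == 12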
...@@ -43,13 +43,14 @@ from transformers.testing_utils import ( # noqa: F401 ...@@ -43,13 +43,14 @@ from transformers.testing_utils import ( # noqa: F401
_tf_gpu_memory_limit, _tf_gpu_memory_limit,
is_pt_tf_cross_test, is_pt_tf_cross_test,
is_staging_test, is_staging_test,
require_safetensors,
require_tf, require_tf,
require_tf2onnx, require_tf2onnx,
slow, slow,
tooslow, tooslow,
torch_device, torch_device,
) )
from transformers.utils import logging from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from transformers.utils.generic import ModelOutput from transformers.utils.generic import ModelOutput
...@@ -94,12 +95,7 @@ if is_tf_available(): ...@@ -94,12 +95,7 @@ if is_tf_available():
TFSampleDecoderOnlyOutput, TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput, TFSampleEncoderDecoderOutput,
) )
from transformers.modeling_tf_utils import ( from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
tf_shard_checkpoint,
unpack_inputs,
)
from transformers.tf_utils import stable_softmax from transformers.tf_utils import stable_softmax
if _tf_gpu_memory_limit is not None: if _tf_gpu_memory_limit is not None:
...@@ -119,6 +115,8 @@ if is_tf_available(): ...@@ -119,6 +115,8 @@ if is_tf_available():
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import BertModel
def _config_zero_init(config): def _config_zero_init(config):
configs_no_init = copy.deepcopy(config) configs_no_init = copy.deepcopy(config)
...@@ -2168,7 +2166,7 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -2168,7 +2166,7 @@ class UtilsFunctionsTest(unittest.TestCase):
) )
@slow @slow
def test_special_layer_name_shardind(self): def test_special_layer_name_sharding(self):
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True) retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever) model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
...@@ -2268,6 +2266,54 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -2268,6 +2266,54 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys())) self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys())) self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
+    @require_safetensors
+    def test_safetensors_save_and_load(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True)
+            # No tf_model.h5 file, only a model.safetensors
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
+
+            new_model = TFBertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.weights, new_model.weights):
+                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
+    @is_pt_tf_cross_test
+    def test_safetensors_save_and_load_pt_to_tf(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pt_model.save_pretrained(tmp_dir, safe_serialization=True)
+            # Check we have a model.safetensors file
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+
+            new_model = TFBertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.weights, new_model.weights):
+                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub(self):
+        tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        # Can load from the TF-formatted checkpoint
+        safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
+
+        # Check models are equal
+        for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
+        # Can load from the PyTorch-formatted checkpoint
+        safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
+
+        # Check models are equal
+        for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))


 @require_tf
 @is_staging_test