Unverified Commit 2bd2de62 authored by Arthur, committed by GitHub

Sharding fails in TF when absolute scope was modified if `.` in layer name (#19124)

* simplify loop

* fix layer map split

* update

* update for special variables

* add rag test

* fixup

* revert change: for next PR
parent 614f7d28
@@ -707,8 +707,15 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
     # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
     # the weight, we have to get rid of the first prefix of the name of the layer.
-    model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights)
-    model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
+    model_keys = set()
+    model_layer_map = dict()
+    for i, k in enumerate(model.weights):
+        if "model." in k.name or len(k.name.split("/")) == 1:
+            layer_name = k.name
+        else:
+            layer_name = "/".join(k.name.split("/")[1:])
+        model_keys.add(layer_name)
+        model_layer_map[layer_name] = i

     for shard_file in shard_files:
         state_dict = tf.io.read_file(shard_file)
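As a reading aid (not part of the diff), the renaming rule the new loop applies can be factored out on its own. A minimal sketch; the variable names in the example are hypothetical:

def normalize_weight_name(name: str) -> str:
    # "Special" variables whose scope contains "model." (e.g. shared
    # embeddings that TF BART/RAG create under a modified absolute scope)
    # and flat names without any "/" keep their full name; everything else
    # drops the leading class-name prefix that TF prepends to its weights.
    if "model." in name or len(name.split("/")) == 1:
        return name
    return "/".join(name.split("/")[1:])

print(normalize_weight_name("tf_bert_model/bert/embeddings/word_embeddings/weight:0"))
# -> bert/embeddings/word_embeddings/weight:0
print(normalize_weight_name("tf_rag_model/rag/generator/model.shared/weight:0"))
# -> unchanged: "model." marks a layer name containing "."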
@@ -2211,17 +2218,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             )
             for shard_file, shard in shards.items():
                 with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
-                    save_attributes_to_hdf5_group(
-                        shard_file,
-                        "layer_names",
-                        ["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
-                    )
+                    layers = []
                     for layer in sorted(shard, key=lambda x: x.name):
+                        if "model." in layer.name or len(layer.name.split("/")) == 1:
+                            layer_name = layer.name
+                        else:
+                            layer_name = "/".join(layer.name.split("/")[1:])
                         param_dset = shard_file.create_dataset(
-                            "/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype
+                            layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
                         )
                         param_dset[:] = layer.numpy()
+                        layers.append(layer_name.encode("utf8"))
+                    save_attributes_to_hdf5_group(shard_file, "layer_names", layers)

         if push_to_hub:
             self._upload_modified_files(
...
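A self-contained sketch of the new per-shard write order: datasets are written first and "layer_names" is recorded once at the end, instead of being computed up front from un-normalized names. Plain h5py attributes stand in here for save_attributes_to_hdf5_group (which additionally chunks attribute lists that exceed the HDF5 object-header limit); the file and weight names below are hypothetical:

import os
import tempfile

import h5py
import numpy as np

weights = {
    "bert/pooler/dense/kernel:0": np.ones((2, 2), dtype=np.float32),
    "bert/pooler/dense/bias:0": np.zeros((2,), dtype=np.float32),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    shard_path = os.path.join(tmp_dir, "tf_model-00001-of-00001.h5")
    with h5py.File(shard_path, mode="w") as shard_file:
        layers = []
        for name, value in sorted(weights.items()):
            # One dataset per weight, keyed by the normalized layer name.
            param_dset = shard_file.create_dataset(name, value.shape, dtype=value.dtype)
            param_dset[:] = value
            layers.append(name.encode("utf8"))
        # Record the names once, after all datasets were written.
        shard_file.attrs["layer_names"] = layers

    with h5py.File(shard_path, mode="r") as shard_file:
        print([n.decode("utf8") for n in shard_file.attrs["layer_names"]])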
@@ -77,9 +77,11 @@ if is_tf_available():
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        RagRetriever,
         TFAutoModel,
         TFAutoModelForSequenceClassification,
         TFBertModel,
+        TFRagModel,
         TFSharedEmbeddings,
     )
     from transformers.generation_tf_utils import (
@@ -2167,6 +2169,18 @@ class UtilsFunctionsTest(unittest.TestCase):
             },
         )

+    @slow
+    def test_special_layer_name_sharding(self):
+        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
+        model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
+                model.save_pretrained(tmp_dir, max_shard_size=max_size)
+                ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
+                for p1, p2 in zip(model.weights, ref_model.weights):
+                    assert np.allclose(p1.numpy(), p2.numpy())
+
     def test_checkpoint_sharding_local(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
...
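The test deliberately mixes decimal and binary size suffixes ("kB" vs "kiB") to cover both branches of max_shard_size parsing. A minimal, hypothetical stand-in for that parsing (the real helper in transformers also accepts plain integers):

def parse_size(size: str) -> int:
    # Binary ("i") suffixes are powers of 2, decimal suffixes powers of 10.
    units = {"KIB": 2**10, "MIB": 2**20, "GIB": 2**30, "KB": 10**3, "MB": 10**6, "GB": 10**9}
    upper = size.upper()
    for suffix, factor in units.items():
        if upper.endswith(suffix):
            return int(upper[: -len(suffix)]) * factor
    return int(size)

assert parse_size("150kB") == 150_000   # decimal kilobytes
assert parse_size("150kiB") == 153_600  # binary kibibytes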