"vscode:/vscode.git/clone" did not exist on "8a312956fd49efd69adb98c40996719d4c276a01"
modeling_tf_utils.py 126 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
Julien Plu's avatar
Julien Plu committed
17

18
import functools
Arthur's avatar
Arthur committed
19
import gc
Julien Plu's avatar
Julien Plu committed
20
import inspect
Arthur's avatar
Arthur committed
21
import json
thomwolf's avatar
thomwolf committed
22
import os
23
import pickle
24
import re
Julien Plu's avatar
Julien Plu committed
25
import warnings
26
from collections.abc import Mapping
27
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
thomwolf's avatar
thomwolf committed
28

Aymeric Augustin's avatar
Aymeric Augustin committed
29
import h5py
Julien Chaumond's avatar
Julien Chaumond committed
30
import numpy as np
thomwolf's avatar
thomwolf committed
31
import tensorflow as tf
Julien Plu's avatar
Julien Plu committed
32
from tensorflow.python.keras import backend as K
Matt's avatar
Matt committed
33
from tensorflow.python.keras.engine import data_adapter
34
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
thomwolf's avatar
thomwolf committed
35
from tensorflow.python.keras.saving import hdf5_format
thomwolf's avatar
thomwolf committed
36

37
from huggingface_hub import Repository, list_repo_files
Arthur's avatar
Arthur committed
38
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
39
from requests import HTTPError
Arthur's avatar
Arthur committed
40
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
41

42
from . import DataCollatorWithPadding, DefaultDataCollator
43
from .activations_tf import get_tf_activation
thomwolf's avatar
thomwolf committed
44
from .configuration_utils import PretrainedConfig
45
from .dynamic_module_utils import custom_object_save
46
47
48
from .generation_tf_utils import TFGenerationMixin
from .tf_utils import shape_list
from .utils import (
Julien Plu's avatar
Julien Plu committed
49
    DUMMY_INPUTS,
50
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
Arthur's avatar
Arthur committed
51
    TF2_WEIGHTS_INDEX_NAME,
Julien Plu's avatar
Julien Plu committed
52
    TF2_WEIGHTS_NAME,
53
    WEIGHTS_INDEX_NAME,
Julien Plu's avatar
Julien Plu committed
54
    WEIGHTS_NAME,
55
    EntryNotFoundError,
Julien Plu's avatar
Julien Plu committed
56
    ModelOutput,
Sylvain Gugger's avatar
Sylvain Gugger committed
57
    PushToHubMixin,
58
59
    RepositoryNotFoundError,
    RevisionNotFoundError,
Julien Plu's avatar
Julien Plu committed
60
    cached_path,
61
    copy_func,
62
    find_labels,
63
    has_file,
Julien Plu's avatar
Julien Plu committed
64
    hf_bucket_url,
65
    is_offline_mode,
Julien Plu's avatar
Julien Plu committed
66
    is_remote_url,
67
    logging,
68
    requires_backends,
Julien Plu's avatar
Julien Plu committed
69
)
thomwolf's avatar
thomwolf committed
70

Aymeric Augustin's avatar
Aymeric Augustin committed
71

72
73
74
75
if TYPE_CHECKING:
    from . import PreTrainedTokenizerBase


Lysandre Debut's avatar
Lysandre Debut committed
76
logger = logging.get_logger(__name__)
77
tf_logger = tf.get_logger()
thomwolf's avatar
thomwolf committed
78

Julien Plu's avatar
Julien Plu committed
79
TFModelInputType = Union[
80
81
82
83
84
85
86
87
88
    List[tf.Tensor],
    List[np.ndarray],
    List[KerasTensor],
    Dict[str, tf.Tensor],
    Dict[str, np.ndarray],
    Dict[str, KerasTensor],
    tf.Tensor,
    np.ndarray,
    KerasTensor,
Julien Plu's avatar
Julien Plu committed
89
90
]

91

Matt's avatar
Matt committed
92
93
94
95
def dummy_loss(y_true, y_pred):
    """
    Placeholder loss with the `(y_true, y_pred)` signature that simply averages `y_pred`.

    Args:
        y_true: Unused; accepted only so the function matches the two-argument loss signature.
        y_pred: Tensor whose mean is returned.

    Returns:
        The result of `tf.reduce_mean(y_pred)`.
    """
    mean_prediction = tf.reduce_mean(y_pred)
    return mean_prediction


96
class TFModelUtilsMixin:
    """
    Small collection of helpers meant to be mixed into `tf.keras.Model` subclasses.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Count the parameters of the model.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                When `True`, only the variables listed in `self.trainable_variables` are counted.

        Returns:
            `int`: The number of parameters.
        """
        if not only_trainable:
            # The full count is delegated to `count_params`.
            return self.count_params()

        # Each variable contributes the product of its shape dimensions.
        per_variable_sizes = [np.prod(variable.shape.as_list()) for variable in self.trainable_variables]
        return int(sum(per_variable_sizes))


118
def keras_serializable(cls):
    """
    Class decorator adding Keras serialization support to a layer whose initializer takes a `config` argument.

    The decorator:

    1. Wraps `__init__` so that `config` may be given either as a `PretrainedConfig` (positionally or by keyword) or
       as a dict passed by keyword, in which case it is converted with `config_class.from_dict`. The resolved config
       and the remaining keyword arguments are stored on the instance as `_config` and `_kwargs`.
    2. When the class still carries a `get_config` flagged with `_is_default`, replaces it with one that stores
       `self._config.to_dict()` under the `"config"` key and merges `self._kwargs` into the returned dictionary.
    3. Marks the class with `_keras_serializable = True` and, if `tf.keras.utils.register_keras_serializable` exists,
       registers the class through it.

    Args:
        cls (a `tf.keras.layers.Layer` subclass):
            The class to decorate. It must define a `config_class` attribute and a `get_config` method.

    Returns:
        The decorated class.

    Raises:
        AttributeError: If `cls` has no `config_class` (or it is `None`).
        TypeError: If `cls` has no `get_config` attribute.
    """
    original_init = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        # A config given positionally wins; otherwise look for it among the keyword arguments.
        if args and isinstance(args[0], PretrainedConfig):
            config = args[0]
        else:
            config = kwargs.pop("config", None)

        if isinstance(config, dict):
            config = config_class.from_dict(config)
            init_args = (config, *args)
        elif isinstance(config, PretrainedConfig):
            # When positional arguments exist they are forwarded untouched; otherwise the config is the only one.
            init_args = args if len(args) > 0 else (config,)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)")

        original_init(self, *init_args, **kwargs)

        self._config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")

    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            serialized = super(cls, self).get_config()
            serialized["config"] = self._config.to_dict()
            serialized.update(self._kwargs)
            return serialized

        cls.get_config = get_config

    cls._keras_serializable = True

    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls
181
182


183
class TFCausalLanguageModelingLoss:
    """
    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

    def hf_compute_loss(self, labels, logits):
        """
        Compute the sparse categorical cross-entropy between `labels` and `logits`, skipping labels equal to -100.

        When `self.config.tf_legacy_loss` is set, the positions labelled -100 are removed and the per-position losses
        of the remaining ones are returned. Otherwise the masked losses are averaged over the kept positions and
        returned as a tensor of shape `(1,)`.
        """
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )

        if not self.config.tf_legacy_loss:
            # Negative labels are clipped to zero so the loss can be evaluated everywhere; the corresponding
            # positions are zeroed out by the mask right after.
            per_token_loss = cross_entropy(tf.nn.relu(labels), logits)
            keep = tf.cast(labels != -100, dtype=per_token_loss.dtype)
            mean_loss = tf.reduce_sum(per_token_loss * keep) / tf.reduce_sum(keep)
            return tf.reshape(mean_loss, (1,))

        # Legacy behaviour: drop the ignored positions before evaluating the loss.
        flat_labels = tf.reshape(labels, (-1,))
        keep = tf.not_equal(flat_labels, -100)
        flat_logits = tf.reshape(logits, (-1, shape_list(logits)[2]))
        kept_logits = tf.boolean_mask(flat_logits, keep)
        kept_labels = tf.boolean_mask(flat_labels, keep)
        return cross_entropy(kept_labels, kept_logits)
212
213


Julien Plu's avatar
Julien Plu committed
214
class TFQuestionAnsweringLoss:
    """
    Loss function suitable for question answering.
    """

    def hf_compute_loss(self, labels, logits):
        """
        Average the start-position and end-position cross-entropy losses.

        `labels` is indexed with the keys `"start_position"` and `"end_position"`; `logits[0]` and `logits[1]` are
        used as the matching predictions. No reduction is applied by the loss object itself.
        """
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        loss_for_start = cross_entropy(labels["start_position"], logits[0])
        loss_for_end = cross_entropy(labels["end_position"], logits[1])
        total = loss_for_start + loss_for_end

        return total / 2.0


class TFTokenClassificationLoss:
    """
    Loss function suitable for token classification.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

    def hf_compute_loss(self, labels, logits):
        """
        Compute the sparse categorical cross-entropy for token classification, skipping masked labels.

        When `self.config.tf_legacy_loss` is set, masked positions (-1 if any label equals -1, otherwise -100) are
        removed and the per-position losses are returned. Otherwise every negative label is masked, the masked losses
        are averaged over the kept positions, and a tensor of shape `(1,)` is returned.
        """
        deprecation_message = "Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead."
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )

        # Data-dependent conditionals are forbidden in XLA, so this check only runs eagerly
        if tf.executing_eagerly() and tf.math.reduce_any(labels == -1):
            tf.print(deprecation_message)

        if not self.config.tf_legacy_loss:
            # Negative labels are clipped to zero so the loss can be evaluated everywhere; the corresponding
            # positions are zeroed out by the mask right after.
            per_token_loss = cross_entropy(tf.nn.relu(labels), logits)
            # Both -100 and -1 are negative, hence a single `>= 0` test keeps only the real labels
            keep = tf.cast(labels >= 0, dtype=per_token_loss.dtype)
            mean_loss = tf.reduce_sum(per_token_loss * keep) / tf.reduce_sum(keep)
            return tf.reshape(mean_loss, (1,))

        # Legacy behaviour: drop the ignored positions before evaluating the loss.
        flat_labels = tf.reshape(labels, (-1,))
        if tf.math.reduce_any(labels == -1):
            tf.print(deprecation_message)
            keep = flat_labels != -1
        else:
            keep = flat_labels != -100
        flat_logits = tf.reshape(logits, (-1, shape_list(logits)[2]))
        kept_logits = tf.boolean_mask(flat_logits, keep)
        kept_labels = tf.boolean_mask(flat_labels, keep)

        return cross_entropy(kept_labels, kept_logits)
Julien Plu's avatar
Julien Plu committed
271
272
273


class TFSequenceClassificationLoss:
    """
    Loss function suitable for sequence classification.
    """

    def hf_compute_loss(self, labels, logits):
        """
        Return the unreduced loss between `labels` and `logits`.

        Mean squared error is used when `logits` has rank 1 or a second dimension of size 1; sparse categorical
        cross-entropy (from logits) is used otherwise.
        """
        no_reduction = tf.keras.losses.Reduction.NONE
        single_output = logits.shape.rank == 1 or logits.shape[1] == 1

        if single_output:
            criterion = tf.keras.losses.MeanSquaredError(reduction=no_reduction)
        else:
            criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=no_reduction)

        return criterion(labels, logits)


Matt's avatar
Matt committed
289
class TFMultipleChoiceLoss:
    """Loss function suitable for multiple choice tasks."""

    def hf_compute_loss(self, labels, logits):
        """Return the unreduced sparse categorical cross-entropy (from logits) between `labels` and `logits`."""
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        per_example_loss = cross_entropy(labels, logits)
        return per_example_loss

Sylvain Gugger's avatar
Sylvain Gugger committed
298
299
300

class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.

    This class defines no methods of its own: the loss computation is inherited unchanged from
    `TFCausalLanguageModelingLoss`.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """
Julien Plu's avatar
Julien Plu committed
309
310


311
312
313
314
class TFNextSentencePredictionLoss:
    """
    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

    def hf_compute_loss(self, labels, logits):
        """
        Compute the unreduced sparse categorical cross-entropy for NSP, skipping labels equal to -100.

        When `self.config.tf_legacy_loss` is set, samples labelled -100 are removed before the loss is evaluated.
        Otherwise the loss is evaluated for every sample and the entries whose label is -100 are set to zero.
        """
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )

        if not self.config.tf_legacy_loss:
            # Negative labels are clipped to zero so the loss can be evaluated everywhere; the corresponding
            # samples are zeroed out by the mask right after (no reduction is applied).
            per_sample_loss = cross_entropy(y_true=tf.nn.relu(labels), y_pred=logits)
            keep = tf.cast(labels != -100, dtype=per_sample_loss.dtype)
            return per_sample_loss * keep

        # Legacy behaviour: drop the ignored samples before evaluating the loss.
        flat_labels = tf.reshape(labels, (-1,))
        keep = tf.not_equal(flat_labels, -100)
        kept_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), keep)
        kept_labels = tf.boolean_mask(flat_labels, keep)

        return cross_entropy(kept_labels, kept_logits)
345
346


347
348
349
350
351
352
def booleans_processing(config, **kwargs):
    """
    Resolve the boolean call options of a model so that they are compliant with the execution mode (eager or graph).

    In eager mode an explicitly passed value wins and `config` only supplies the fallback for `None`. In graph mode the
    values always come from `config`, and `return_dict` is forced to `True` (a warning is logged if another non-`None`
    value was requested).

    Args:
        config ([`PretrainedConfig`]):
            The config of the running model.
        **kwargs:
            The boolean parameters. `output_attentions` and `use_cache` are only processed when present.

    Returns:
        A dictionary with the proper values for each boolean
    """
    resolved = {}

    # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has
    # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`)
    wants_attentions = "output_attentions" in kwargs
    wants_cache = "use_cache" in kwargs

    if not tf.executing_eagerly():
        if wants_attentions:
            resolved["output_attentions"] = config.output_attentions

        resolved["output_hidden_states"] = config.output_hidden_states

        if kwargs.get("return_dict", None) not in (None, True):
            tf_logger.warning(
                "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`."
            )
        resolved["return_dict"] = True

        if wants_cache:
            resolved["use_cache"] = getattr(config, "use_cache", None)

        return resolved

    if wants_attentions:
        requested = kwargs["output_attentions"]
        resolved["output_attentions"] = config.output_attentions if requested is None else requested

    requested = kwargs["output_hidden_states"]
    resolved["output_hidden_states"] = config.output_hidden_states if requested is None else requested

    requested = kwargs["return_dict"]
    resolved["return_dict"] = config.return_dict if requested is None else requested

    if wants_cache:
        requested = kwargs["use_cache"]
        resolved["use_cache"] = getattr(config, "use_cache", None) if requested is None else requested

    return resolved


402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
def unpack_inputs(func):
    """
    Decorator for the call method of a Keras layer: every input is normalized through `input_processing` and handed to
    the wrapped function as a keyword argument, so the function body can refer to its inputs by name even when they
    arrive packed (for instance as a dictionary) in the first argument.

    Args:
        func (`callable`):
            The callable function of the TensorFlow model.

    Returns:
        A callable that wraps the original `func` with the behavior described above.
    """

    original_signature = inspect.signature(func)

    @functools.wraps(func)
    def run_call_with_unpacked_inputs(self, *args, **kwargs):
        # Split the keyword arguments: those named in the signature go to the function, the rest are the actual
        # `**kwargs` of the call
        declared_parameters = dict(original_signature.parameters)
        kwargs_call = {}
        fn_args_and_kwargs = {}
        for key, val in kwargs.items():
            if key in declared_parameters:
                fn_args_and_kwargs[key] = val
            else:
                kwargs_call[key] = val
        fn_args_and_kwargs["kwargs_call"] = kwargs_call

        # Positional arguments are renamed after the function's variable names (skipping `self`)
        for positional_name, positional_value in zip(func.__code__.co_varnames[1:], args):
            fn_args_and_kwargs[positional_name] = positional_value

        # The main input is pulled out and passed separately to `input_processing`
        main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
        main_input = fn_args_and_kwargs.pop(main_input_name, None)
        unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
        return func(self, **unpacked_inputs)

    # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
    # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below
    # Keras would attempt to check the first argument against the literal signature of the wrapper.
    run_call_with_unpacked_inputs.__signature__ = original_signature

    return run_call_with_unpacked_inputs


442
443
def input_processing(func, config, input_ids, **kwargs):
    """
    Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
    has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
    name="input_ids")` otherwise the order of the tensors will not be guaranteed during the training.

    Args:
        func (`callable`):
            The callable function of the TensorFlow model.
        config ([`PretrainedConfig`]):
            The config of the running model.
        input_ids:
            The main input. It may be a list/tuple of inputs (matched to the parameters by tensor name or by
            position), a mapping from parameter names to values, a single `tf.Tensor`/`KerasTensor`, or `None`. A
            mapping passed here is never modified.
        **kwargs:
            The inputs of the model. Must contain a `kwargs_call` entry: the dictionary of keyword arguments that do
            not belong to the signature of `func`.

    Returns:
        A dictionary mapping the parameter names of `func` to their processed values. Unspecified parameters receive
        the default of the signature and the boolean options are resolved through `booleans_processing`.

    Raises:
        ValueError: If a value has a type that is not allowed, or if unsupported keyword arguments are passed while
            `func` has no `kwargs` parameter.
    """
    signature = dict(inspect.signature(func).parameters)
    has_kwargs = bool(signature.pop("kwargs", None))
    signature.pop("self", None)
    parameter_names = list(signature.keys())
    output = {}
    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)

    # Deprecated keyword arguments are mapped to their current name
    if "inputs" in kwargs["kwargs_call"]:
        warnings.warn(
            "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
            FutureWarning,
        )

        output["input_ids"] = kwargs["kwargs_call"].pop("inputs")

    if "decoder_cached_states" in kwargs["kwargs_call"]:
        warnings.warn(
            "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
            " `past_key_values` instead.",
            FutureWarning,
        )
        output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")

    if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
        warnings.warn(
            "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`"
            " instead.",
            FutureWarning,
        )
        kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
    elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
        kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")

    if has_kwargs:
        output["kwargs"] = kwargs.pop("kwargs_call", {})
    else:
        if len(kwargs["kwargs_call"]) > 0:
            raise ValueError(
                "The following keyword arguments are not supported by this model:"
                f" {list(kwargs['kwargs_call'].keys())}."
            )
        kwargs.pop("kwargs_call")

    for k, v in kwargs.items():
        if isinstance(v, allowed_types) or v is None:
            output[k] = v
        else:
            raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")

    if isinstance(input_ids, (tuple, list)):
        for i, item in enumerate(input_ids):
            # EagerTensors don't allow to use the .name property so we check for a real Tensor
            # (an exact type comparison on purpose, not `isinstance`)
            if type(item) == tf.Tensor:
                # Tensor names have always the pattern `name:id` then we check only the
                # `name` part
                tensor_name = item.name.split(":")[0]

                if tensor_name in parameter_names:
                    output[tensor_name] = item
                else:
                    output[parameter_names[i]] = item
            elif isinstance(item, allowed_types) or item is None:
                output[parameter_names[i]] = item
            else:
                raise ValueError(
                    f"Data of type {type(item)} is not allowed only {allowed_types} is accepted for"
                    f" {parameter_names[i]}."
                )
    elif isinstance(input_ids, Mapping):
        # Work on a shallow copy: the deprecated keys are popped below, which must neither alter the caller's
        # mapping nor fail on read-only `Mapping` implementations (they do not provide `pop`)
        input_ids = dict(input_ids)

        if "inputs" in input_ids:
            warnings.warn(
                "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
                " instead.",
                FutureWarning,
            )

            output["input_ids"] = input_ids.pop("inputs")

        if "decoder_cached_states" in input_ids:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
                " `past_key_values` instead.",
                FutureWarning,
            )
            output["past_key_values"] = input_ids.pop("decoder_cached_states")

        for k, v in input_ids.items():
            if isinstance(v, allowed_types) or v is None:
                output[k] = v
            elif k not in parameter_names and "args" not in parameter_names:
                logger.warning(
                    f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
                )
                continue
            else:
                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
    else:
        if isinstance(input_ids, (tf.Tensor, KerasTensor)) or input_ids is None:
            output[parameter_names[0]] = input_ids
        else:
            raise ValueError(
                f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for"
                f" {parameter_names[0]}."
            )

    # Populates any unspecified argument with their default value, according to the signature.
    for name in parameter_names:
        if name not in output and name != "args":
            output[name] = kwargs.pop(name, signature[name].default)

    # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
    # So to respect the proper output we have to add this exception
    if "args" in output:
        if output["args"] is not None and type(output["args"]) == tf.Tensor:
            tensor_name = output["args"].name.split(":")[0]
            output[tensor_name] = output["args"]
        else:
            # `args` in this case is always the first parameter, then `input_ids`
            output["input_ids"] = output["args"]

        del output["args"]

    if "kwargs" in output:
        del output["kwargs"]

    boolean_dict = {
        k: v
        for k, v in output.items()
        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
    }

    output.update(
        booleans_processing(
            config=config,
            **boolean_dict,
        )
    )

    return output


Arthur's avatar
Arthur committed
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    The size is derived from the trailing digits of `dtype.name` (the bit width). `tf.bool` is special-cased and
    counted as 1/8 of a byte.

    Example:

    ```py
    >>> dtype_byte_size(tf.float32)
    4
    ```

    Raises:
        ValueError: If `dtype.name` does not end with a non-digit character followed by digits.
    """
    if dtype == tf.bool:
        return 1 / 8
    # Raw string: `\d` inside a regular string literal is an invalid escape sequence
    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


def tf_shard_checkpoint(weights, max_shard_size="10GB"):
    """
    Splits a list of model weights in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through `weights` in order, so there is no optimization made to
    make each sub-checkpoint as close as possible to the maximum size passed. For example, if the limit is 10GB and we
    have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
    [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        weights (iterable of weight variables):
            The weights of a model to save. Each item must expose `.numpy()`, `.dtype` and `.name`.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).

    Returns:
        A tuple `(shards, index)`. `shards` maps a shard file name to the list of weights stored in it. `index` is
        `None` when everything fits in a single shard; otherwise it is a dict with a `"metadata"` entry (holding
        `"total_size"`) and a `"weight_map"` entry mapping each weight name to its shard file name.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = []
    current_block_size = 0
    total_size = 0

    for item in weights:
        weight_size = item.numpy().size * dtype_byte_size(item.dtype)

        # If this weight is going to tip up over the maximal size, we split. The `current_block` guard prevents
        # closing an empty block (and thus writing an empty shard) when a single weight exceeds the limit on its own.
        if current_block and current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = []
            current_block_size = 0

        current_block.append(item)
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    num_shards = len(sharded_state_dicts)
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{num_shards:05d}.h5")
        shards[shard_file] = shard
        for weight in shard:
            weight_map[weight.name] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index


def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True):
    """
    This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
    the TF weights from the shard files according to their names and shapes.

    Shards are processed one at a time through `load_tf_shard`, and `gc.collect()` is called after each of them.

    Args:
        model (`tf.keras.models.Model`): The model in which to load the checkpoint.
        shard_files (list of `str` or `os.PathLike`): The paths of the checkpoint shards, processed in order.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether or not to ignore the mismatch between the sizes. Forwarded to `load_tf_shard`.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to strictly enforce that the keys in the model match the keys in the sharded checkpoint.

    Returns:
        Three sets: the missing keys (model weights found in no shard), the unexpected keys (checkpoint entries
        with no matching model weight) and the mismatched keys reported by `load_tf_shard`.

    Raises:
        RuntimeError: If `strict` is set and there is at least one missing or unexpected key.
    """
    unexpected_keys = set()
    saved_keys = set()
    missmatched_keys = set()

    # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
    # the weight, we have to get rid of the first prefix of the name of the layer.
    model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
    model_keys = set(model_layer_map)

    for shard_file in shard_files:
        # `load_tf_shard` opens the file itself; reading it here as well would only hold a second copy in memory.
        saved_weight_names_set, unexpected_keys_set, missmatched_keys_set = load_tf_shard(
            model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes
        )
        saved_keys.update(saved_weight_names_set)
        unexpected_keys.update(unexpected_keys_set)
        missmatched_keys.update(missmatched_keys_set)
        gc.collect()

    missing_keys = model_keys - saved_keys
    if strict and (missing_keys or unexpected_keys):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if missing_keys:
            str_missing_keys = ",".join(f'"{k}"' for k in sorted(missing_keys))
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if unexpected_keys:
            str_unexpected_keys = ",".join(f'"{k}"' for k in sorted(unexpected_keys))
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)

    return missing_keys, unexpected_keys, missmatched_keys


def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False):
    """
    Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.

    Args:
        model (`tf.keras.models.Model`): Model in which the weights are loaded
        model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model.
        resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys

    Returns:
        Three sets: the names of all the entries found in the shard file, the entries of the shard that have no
        counterpart in `model_layer_map` (unexpected), and the mismatched entries as
        `(name, saved_shape, model_shape)` tuples (only filled when `ignore_mismatched_sizes` is set).

    Raises:
        OSError: If the file cannot be loaded as a checkpoint shard, including when it looks like a git-lfs pointer
            file.
    """
    saved_weight_names_set = set()
    missmatched_keys = set()
    unexpected_keys = set()
    # Read the H5 file
    try:
        with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
            # Retrieve the name of each layer from the H5 file
            saved_h5_model_layers_name = set(
                hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
            )
            weight_value_tuples = []

            # Compute missing and unexpected sub layers
            # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
            for layer_name in saved_h5_model_layers_name:
                saved_weight_value = np.asarray(sharded_checkpoint_file[layer_name])
                saved_weight_names_set.add(layer_name)

                if layer_name not in model_layer_map:
                    unexpected_keys.add(layer_name)
                    continue

                symbolic_weight = model.weights[model_layer_map[layer_name]]
                expected_shape = K.int_shape(symbolic_weight)

                # Check if the shape of the current weight and the one from the H5 file are different
                if expected_shape != saved_weight_value.shape:
                    # If yes we reshape the weight from the H5 file accordingly to the current weight
                    # If the two shapes are not compatible we raise an issue
                    try:
                        array = np.reshape(saved_weight_value, expected_shape)
                    except ValueError:
                        if ignore_mismatched_sizes:
                            missmatched_keys.add((layer_name, saved_weight_value.shape, expected_shape))
                            continue
                        # NOTE(review): this error is caught by the broad handler below and ends up reported as an
                        # OSError; kept as is so the exception type seen by callers does not change.
                        raise
                else:
                    array = saved_weight_value

                # We create the tuple that will be loaded and add it to the final list
                weight_value_tuples.append((symbolic_weight, array))

        K.batch_set_value(weight_value_tuples)

        return saved_weight_names_set, unexpected_keys, missmatched_keys

    except Exception as e:
        try:
            with open(resolved_archive_file) as f:
                # Only the first characters are needed to recognise a git-lfs pointer file, so avoid reading a
                # potentially huge file entirely.
                if f.read(7).startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    ) from e
                else:
                    raise ValueError(
                        f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
                        " model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from TF checkpoint file at '{resolved_archive_file}'. "
                "If you tried to load a TF model from a sharded checkpoint, you should try converting the model "
                "by loading it in pytorch and saving it locally. A conversion script should be released soon."
            ) from e


833
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """
    Detect missing and unexpected layers and load the TF weights from the H5 file according to their names and
    shapes.

    Args:
        model (`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (`str`):
            The location of the H5 file.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether or not to ignore weights with shapes that don't match between the checkpoint and the model.
        _prefix (`str`, *optional*):
            When given, it is prepended (followed by `/`) to the saved weight names, and the model weight names are
            matched after dropping the path component that sits right after the prefix instead of the first one.

    Returns:
        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
        mismatched layers.
    """
    mismatched_layers = []

    # Read the H5 file
    with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
        # Retrieve the name of each layer from the H5 file
        saved_h5_model_layers_name = set(
            hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
        )
        model_layers_name = {layer.name for layer in model.layers}

        # Find the missing layers from the high level list of layers
        missing_layers = list(model_layers_name - saved_h5_model_layers_name)

        # Find the unexpected layers from the high level list of layers
        unexpected_layers = list(saved_h5_model_layers_name - model_layers_name)
        saved_weight_names_set = set()
        symbolic_weights_names = set()
        weight_value_tuples = []

        # Compute missing and unexpected sub layers
        # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
        for layer in model.layers:
            # Only layers that are both in the H5 file and in the instantiated model are loaded
            if layer.name not in saved_h5_model_layers_name:
                continue

            # Get the H5 layer object from its name
            h5_layer_object = sharded_checkpoint_file[layer.name]
            # Get all the weights as a list from the layer object
            symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
            saved_weights = {}

            # Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
            # And a set with only the names
            for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
                # TF names always start with the model name so we ignore it
                name = "/".join(weight_name.split("/")[1:])

                if _prefix is not None:
                    name = _prefix + "/" + name

                saved_weights[name] = np.asarray(h5_layer_object[weight_name])

                # Add the updated name to the final list for computing missing/unexpected values
                saved_weight_names_set.add(name)

            # Loop over each weights from the instantiated model and compare with the weights from the H5 file
            for symbolic_weight in symbolic_weights:
                name_parts = symbolic_weight.name.split("/")
                if _prefix is not None:
                    # Keep the leading components covered by the prefix and drop the one right after them
                    prefix_depth = len(_prefix.split("/"))
                    symbolic_weight_name = "/".join(name_parts[:prefix_depth] + name_parts[prefix_depth + 1 :])
                else:
                    # TF names always start with the model name so we ignore it
                    symbolic_weight_name = "/".join(name_parts[1:])

                # here we check if the current weight is among the weights from the H5 file
                # If yes, get the weight_value of the corresponding weight from the H5 file
                # If not, make the value to None
                saved_weight_value = saved_weights.get(symbolic_weight_name, None)

                # Add the updated name to the final list for computing missing/unexpected values
                symbolic_weights_names.add(symbolic_weight_name)

                # Nothing to load when the weight is absent from the H5 file
                if saved_weight_value is None:
                    continue

                # Check if the shape of the current weight and the one from the H5 file are different
                if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                    # If yes we reshape the weight from the H5 file accordingly to the current weight
                    # If the two shapes are not compatible we raise an issue
                    try:
                        array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                    except ValueError:
                        if ignore_mismatched_sizes:
                            mismatched_layers.append(
                                (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                            )
                            continue
                        raise
                else:
                    array = saved_weight_value

                # We create the tuple that will be loaded and add it to the final list
                weight_value_tuples.append((symbolic_weight, array))

    # Load all the weights
    K.batch_set_value(weight_value_tuples)

    # Compute the missing and unexpected layers
    missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
    unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers, mismatched_layers
Julien Plu's avatar
Julien Plu committed
945

Julien Plu's avatar
Julien Plu committed
946

947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
def init_copy_embeddings(old_embeddings, new_num_tokens):
    r"""
    Build a copy of `old_embeddings` resized to `new_num_tokens` rows, together with a boolean column mask telling
    which rows carry values copied from the old embeddings.

    When growing, the extra rows are padded with -1 and flagged `False` in the mask; when shrinking (or keeping the
    size), the first `new_num_tokens` rows are kept and every mask entry is `True`. Example:

        - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]

            -  mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
        - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]

            - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]

    Returns:
        A tuple `(mask, current_weights)`.
    """
    old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
    size_diff = new_num_tokens - old_num_tokens

    if tf.math.greater(size_diff, 0):
        # Growing: append `size_diff` rows filled with -1 below the existing ones; those rows are placeholders
        # that the mask marks as not copied.
        current_weights = tf.pad(
            old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
        )
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        copied_rows = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
        mask = tf.pad(copied_rows, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
        return mask, current_weights

    # Shrinking (or same size): keep only the leading `new_num_tokens` rows, all of which are copied values.
    slice_begin = tf.convert_to_tensor([0, 0])
    slice_size = tf.convert_to_tensor([new_num_tokens, old_embedding_dim])
    current_weights = tf.slice(old_embeddings.value(), slice_begin, slice_size)
    mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)

    return mask, current_weights


Sylvain Gugger's avatar
Sylvain Gugger committed
987
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
988
989
    r"""
    Base class for all TF models.
thomwolf's avatar
thomwolf committed
990

Sylvain Gugger's avatar
Sylvain Gugger committed
991
992
    [`TFPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models as well as a few methods common to all models to:
thomwolf's avatar
thomwolf committed
993

994
995
        - resize the input embeddings,
        - prune heads in the self-attention heads.
thomwolf's avatar
thomwolf committed
996

997
    Class attributes (overridden by derived classes):
Sylvain Gugger's avatar
Sylvain Gugger committed
998

Sylvain Gugger's avatar
Sylvain Gugger committed
999
1000
1001
1002
1003
1004
        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
thomwolf's avatar
thomwolf committed
1005
1006
1007
    """
    config_class = None
    base_model_prefix = ""
1008
    main_input_name = "input_ids"
1009
    _auto_class = None
1010
    _using_dummy_loss = None
1011
    _label_to_output_map = None
1012

1013
1014
1015
1016
1017
1018
    # a list of re pattern of tensor names to ignore from the model when loading the model weights
    # (and avoid unnecessary warnings).
    _keys_to_ignore_on_load_missing = None
    # a list of re pattern of tensor names to ignore from the weights when loading the model weights
    # (and avoid unnecessary warnings).
    _keys_to_ignore_on_load_unexpected = None
Ratthachat (Jung)'s avatar
Ratthachat (Jung) committed
1019
    _requires_load_weight_prefix = False
thomwolf's avatar
thomwolf committed
1020

1021
    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            `Dict[str, tf.Tensor]`: The dummy inputs, a single `"input_ids"` entry built from `DUMMY_INPUTS`.
        """
        return {
            "input_ids": tf.constant(DUMMY_INPUTS),
        }
thomwolf's avatar
thomwolf committed
1032

1033
1034
1035
1036
1037
1038
1039
    @property
    def framework(self) -> str:
        """
        :str: Identifies that this is a TensorFlow model.
        """
        return "tf"

thomwolf's avatar
thomwolf committed
1040
    def __init__(self, config, *inputs, **kwargs):
        """
        Initialize the parent classes with `*inputs` / `**kwargs` and attach the configuration to the model.

        Args:
            config ([`PretrainedConfig`]): The model configuration. It is stored as `self.config` and its
                `name_or_path` is copied to `self.name_or_path`.

        Raises:
            ValueError: If `config` is not a [`PretrainedConfig`] instance.
        """
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PretrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # Save config and origin of the pretrained weights if given in model
        self.config = config
        self.name_or_path = config.name_or_path
thomwolf's avatar
thomwolf committed
1051

1052
    def get_config(self):
1053
        return self.config.to_dict()
1054
1055
1056

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        Build a model from a configuration.

        Args:
            config ([`PretrainedConfig`] or `dict`): Either a configuration object, forwarded with `kwargs` to
                `cls._from_config`, or a dictionary that is first converted with `cls.config_class.from_dict` (which
                receives `kwargs` in that case).

        Returns:
            The object returned by `cls._from_config`.
        """
        if isinstance(config, PretrainedConfig):
            return cls._from_config(config, **kwargs)
        return cls._from_config(cls.config_class.from_dict(config, **kwargs))
1060

1061
1062
1063
1064
1065
1066
1067
    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

Julien Plu's avatar
Julien Plu committed
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
                "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
                "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
            }
        ]
    )
    def serving(self, inputs):
        """
        Method used for serving the model.

        Args:
            inputs (`Dict[str, tf.Tensor]`):
                The input of the saved model as a dictionary of tensors.

        Returns:
            The result of `self.serving_output` applied to the output of `self.call(inputs)`.
        """
        output = self.call(inputs)

        return self.serving_output(output)

    def serving_output(output):
        """
        Prepare the output of the saved model. Each model must implement this function.

        Args:
1094
            output ([`TFBaseModelOutput`]):
Julien Plu's avatar
Julien Plu committed
1095
1096
1097
1098
                The output returned by the model.
        """
        raise NotImplementedError

1099
    def get_input_embeddings(self) -> tf.keras.layers.Layer:
1100
        """
1101
        Returns the model's input embeddings layer.
1102
1103

        Returns:
1104
            `tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
1105
        """
1106
        main_layer = getattr(self, self.base_model_prefix, self)
Julien Plu's avatar
Julien Plu committed
1107

1108
1109
        if main_layer is not self:
            return main_layer.get_input_embeddings()
1110
1111
1112
        else:
            raise NotImplementedError

1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
    def _save_checkpoint(self, checkpoint_dir, epoch):
        if not os.path.isdir(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer
        # state for us, because it requires special handling for objects like custom losses, which we use
        # internally and which users are likely to use too
        weights_path = os.path.join(checkpoint_dir, "weights.h5")
        self.save_weights(weights_path)
        extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
        extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
        with open(extra_data_path, "wb") as f:
            pickle.dump(extra_data, f)

    def load_repo_checkpoint(self, repo_path_or_name):
        """
        Loads a saved checkpoint (model weights and optimizer state) from a repo. Returns the current epoch count when
        the checkpoint was made.

        Args:
1132
            repo_path_or_name (`str`):
1133
1134
1135
1136
                Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
                the repository will have the name of that local folder).

        Returns:
1137
            `dict`: A dictionary of extra metadata from the checkpoint, most commonly an "epoch" count.
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
        """
        if getattr(self, "optimizer", None) is None:
            raise RuntimeError(
                "Checkpoint loading failed as no optimizer is attached to the model. "
                "This is most likely caused by the model not being compiled."
            )
        if not os.path.isdir(repo_path_or_name):
            # If this isn't a local path, check that the remote repo exists and has a checkpoint in it
            repo_files = list_repo_files(repo_path_or_name)
            for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
                if file not in repo_files:
                    raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
            if "/" not in repo_path_or_name:
                model_id = repo_path_or_name
                repo_path_or_name = self.get_full_repo_name(repo_path_or_name)
            else:
                model_id = repo_path_or_name.split("/")[-1]
            repo = Repository(model_id, clone_from=f"https://huggingface.co/{repo_path_or_name}")
            local_dir = repo.local_dir
        else:
            local_dir = repo_path_or_name

        # Now make sure the repo actually has a checkpoint in it.
        checkpoint_dir = os.path.join(local_dir, "checkpoint")
        weights_file = os.path.join(checkpoint_dir, "weights.h5")
        if not os.path.isfile(weights_file):
            raise FileNotFoundError(f"Could not find checkpoint file weights.h5 in repo {repo_path_or_name}!")
        extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
        if not os.path.isfile(extra_data_file):
            raise FileNotFoundError(f"Could not find checkpoint file extra_data.pickle in repo {repo_path_or_name}!")

        # Assuming the repo is real and we got a checkpoint, load the weights and the optimizer state into the model.
        # The optimizer state includes the iteration count, so learning rate schedules should resume as normal too.
        self.load_weights(weights_file)
        with open(extra_data_file, "rb") as f:
            extra_data = pickle.load(f)
        self.optimizer.set_weights(extra_data["optimizer_state"])

        # Finally, return the epoch number from the checkpoint. This isn't a property of the model, so we can't
        # set it directly, but the user can pass it to fit().
        return {"epoch": extra_data["epoch"]}

1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
    def prepare_tf_dataset(
        self,
        dataset: "datasets.Dataset",  # noqa:F821
        batch_size: int = 8,
        shuffle: bool = True,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        collate_fn: Optional[Callable] = None,
        collate_fn_args: Optional[Dict[str, Any]] = None,
        drop_remainder: Optional[bool] = None,
        prefetch: bool = True,
    ):
        """
        Turns a HuggingFace [`~datasets.Dataset`] into a collated, batched `tf.data.Dataset` that can be handed
        straight to Keras methods such as `fit()`. Columns that are neither arguments of this model's `call` method
        nor named `label`/`label_ids` are dropped before conversion. If you would rather choose the returned columns
        yourself, use `Dataset.to_tf_dataset()` instead.

        Args:
            dataset (`Any`):
                A [`~datasets.Dataset`] to be wrapped as a `tf.data.Dataset`.
            batch_size (`int`, defaults to 8):
                The size of batches to return.
            shuffle (`bool`, defaults to `True`):
                Whether to return samples from the dataset in random order. Usually `True` for training datasets and
                `False` for validation/test datasets.
            tokenizer ([`PreTrainedTokenizerBase`], *optional*):
                A `PreTrainedTokenizer` used to pad samples when building batches. Ignored when `collate_fn` is given.
            collate_fn (`Callable`, *optional*):
                Function that merges a list of samples into one batch. When omitted, `DefaultDataCollator` is used if
                there is no `tokenizer`, otherwise `DataCollatorWithPadding`.
            collate_fn_args (`Dict[str, Any]`, *optional*):
                Extra keyword arguments passed to `collate_fn` together with the list of samples.
            drop_remainder (`bool`, *optional*):
                Whether to drop the final batch when `batch_size` does not evenly divide the dataset length. Defaults
                to the value of `shuffle`.
            prefetch (`bool`, defaults to `True`):
                Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
                performance, but can be disabled in edge cases.

        Returns:
            `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.

        Raises:
            TypeError: If `dataset` is not a `datasets.Dataset`.
        """
        requires_backends(self, ["datasets"])
        import datasets

        if collate_fn is None:
            collate_fn = (
                DefaultDataCollator(return_tensors="tf")
                if tokenizer is None
                else DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
            )
        if collate_fn_args is None:
            collate_fn_args = {}

        if not isinstance(dataset, datasets.Dataset):
            raise TypeError("Dataset argument should be a datasets.Dataset!")

        call_arg_names = list(inspect.signature(self.call).parameters)
        label_names = find_labels(self.__class__)

        # Generic label columns are kept even when they are not `call` arguments.
        generic_label_columns = ("label_ids", "label")
        columns_to_drop = [
            name
            for name in dataset.features
            if not (name in call_arg_names or name in generic_label_columns)
        ]
        dataset = dataset.remove_columns(columns_to_drop)

        # Ask the dataset which columns come out of the collator, then split them into features and labels.
        signature, _ = dataset._get_output_signature(
            dataset,
            batch_size=None,
            collate_fn=collate_fn,
            collate_fn_args=collate_fn_args,
        )
        feature_cols = []
        label_cols = []
        for column in signature.keys():
            if column in label_names:
                label_cols.append(column)
            elif column in call_arg_names:
                feature_cols.append(column)

        if drop_remainder is None:
            drop_remainder = shuffle

        return dataset.to_tf_dataset(
            columns=feature_cols,
            label_cols=label_cols,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_remainder=drop_remainder,
            collate_fn=collate_fn,
            collate_fn_args=collate_fn_args,
            prefetch=prefetch,
        )

Matt's avatar
Matt committed
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
    def compile(
        self,
        optimizer="rmsprop",
        loss="passthrough",
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        run_eagerly=None,
        steps_per_execution=None,
        **kwargs
    ):
        """
        Thin wrapper around the parent `compile`: when the caller leaves `loss` at its `"passthrough"` default, the
        model's own loss output head is used as the loss (via `dummy_loss`) and `self._using_dummy_loss` is set.
        Every other argument is forwarded to the parent unchanged.
        """
        use_internal_loss = bool(loss == "passthrough")
        if use_internal_loss:
            logger.warning(
                "No loss specified in compile() - the model's internal loss computation will be used as the "
                "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
                "To disable this behaviour please pass a loss argument, or explicitly pass "
                "`loss=None` if you do not want your model to compute a loss."
            )
            loss = dummy_loss
        self._using_dummy_loss = use_internal_loss

        # This argument got renamed, we need to support both versions
        parent_params = inspect.signature(tf.keras.Model.compile).parameters
        if "steps_per_execution" in parent_params:
            steps_arg_name = "steps_per_execution"
        else:
            steps_arg_name = "experimental_steps_per_execution"

        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            **{steps_arg_name: steps_per_execution},
            **kwargs,
        )

1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
    def compute_loss(self, *args, **kwargs):
        """
        Dispatches to the Keras `compute_loss` when `tf.keras.Model` defines one, and otherwise falls back to
        `hf_compute_loss` after emitting a `FutureWarning`.
        """
        if not hasattr(tf.keras.Model, "compute_loss"):
            warnings.warn(
                "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
                "method added in TF 2.8. If you want the original HF compute_loss, please call "
                "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
                "calling compute_loss() will get the Keras method instead.",
                FutureWarning,
            )
            return self.hf_compute_loss(*args, **kwargs)

        # Keras provides its own compute_loss (TF 2.8 or greater), so defer to it.
        return super().compute_loss(*args, **kwargs)

1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
    def get_label_to_output_name_mapping(self):
        """
        Returns a dict mapping label argument names to the names of the output heads they correspond to.

        An explicit `self._label_to_output_map` always wins. Otherwise the mapping is chosen from a fixed list, based
        on the first marker argument found in the signature of `self.call`; an empty dict is returned when none of
        them is present.
        """
        call_params = list(inspect.signature(self.call).parameters)
        if self._label_to_output_map is not None:
            return self._label_to_output_map

        # Checked in order: the first marker argument present in `call` decides the mapping.
        known_mappings = (
            ("start_positions", {"start_positions": "start_logits", "end_positions": "end_logits"}),
            ("sentence_order_label", {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}),
            ("next_sentence_label", {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}),
            ("mc_labels", {"labels": "logits", "mc_labels": "mc_logits"}),
        )
        for marker_arg, mapping in known_mappings:
            if marker_arg in call_params:
                return mapping
        return {}

Matt's avatar
Matt committed
1351
1352
    def train_step(self, data):
        """
        A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
        and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
        labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
        that they are available to the model during the forward pass.

        Args:
            data:
                One batch as handed over by Keras. It is unpacked into `(x, y, sample_weight)` with
                `data_adapter.unpack_x_y_sample_weight`.

        Returns:
            `dict`: The current result of every metric in `self.metrics`, keyed by metric name. Metrics whose result
            is itself a dict are merged into the returned dict directly.

        Raises:
            ValueError: If no separate labels were passed, no label column is found in the inputs, and the dummy loss
            is not in use.
        """

        # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
        arg_names = list(dict(inspect.signature(self.call).parameters).keys())
        label_kwargs = find_labels(self.__class__)
        label_to_output = self.get_label_to_output_name_mapping()
        output_to_label = {val: key for key, val in label_to_output.items()}
        if not self._using_dummy_loss:
            data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
        if isinstance(x, dict):
            x = x.copy()
        if isinstance(y, dict):
            y = y.copy()

        # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
        # if those keys are not already present in the input dict
        if self._using_dummy_loss and y is not None:

            # If y is a tensor and the model only has one label-like input, map y to that input
            if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
                if isinstance(x, tf.Tensor):
                    # A bare input tensor is bound to the first argument of `call`.
                    x = {arg_names[0]: x}
                label_kwarg = next(iter(label_kwargs))
                if label_kwarg not in x:
                    x[label_kwarg] = y
            # Otherwise, copy keys from y to x as long as they weren't already present in x
            elif isinstance(y, dict):
                if isinstance(x, tf.Tensor):
                    # A bare input tensor is bound to the first argument of `call`.
                    x = {arg_names[0]: x}
                for key, val in y.items():
                    if key in arg_names and key not in x:
                        x[key] = val
                    # The label may be keyed by output-head name; translate it back to the `call` argument name.
                    elif output_to_label.get(key, None) in arg_names and key not in x:
                        x[output_to_label[key]] = val
        if y is None:
            # No separate labels were given: pull the label columns out of the inputs.
            y = {key: val for key, val in x.items() if key in label_kwargs}
            if not y and not self._using_dummy_loss:
                raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")

        if isinstance(y, dict):
            # Rename labels at this point to match output heads
            y = {label_to_output.get(key, key): val for key, val in y.items()}

        # Run forward pass.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            if self._using_dummy_loss:
                # The model's own loss output is passed as both target and prediction.
                loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
            else:
                # Deferred until labels and outputs have been matched up below.
                loss = None

            # This next block matches outputs to label keys. Tensorflow's standard method for doing this
            # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
            if isinstance(y, dict) and len(y) == 1:
                if list(y.keys())[0] in y_pred.keys():
                    y_pred = y_pred[list(y.keys())[0]]
                elif list(y_pred.keys())[0] == "loss":
                    # Skip the leading `loss` entry and use the next output.
                    y_pred = y_pred[1]
                else:
                    y_pred = y_pred[0]
                # Unwrap the single label so `y` is the bare value rather than a one-entry dict.
                _, y = y.popitem()
            elif isinstance(y, dict):
                # If the labels are a dict, match keys from the output by name
                y_pred = {key: val for key, val in y_pred.items() if key in y}
            elif isinstance(y, tuple) or isinstance(y, list):
                # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
                if list(y_pred.keys())[0] == "loss":
                    y_pred = y_pred.to_tuple()[1:]
                else:
                    y_pred = y_pred.to_tuple()
                y_pred = y_pred[: len(y)]  # Remove unused fields in case those cause problems
            else:
                # If the labels are a single tensor, match them to the first non-loss tensor in the output
                if list(y_pred.keys())[0] == "loss":
                    y_pred = y_pred[1]
                else:
                    y_pred = y_pred[0]

            if loss is None:
                loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)

        # Run backwards pass.
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return return_metrics

    def test_step(self, data):
        """
        A modification of Keras's default `test_step` that correctly handles matching outputs to labels for our models
        and supports evaluating directly on the loss output head. In addition, it ensures input keys are copied to the
        labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
        that they are available to the model during the forward pass.

        Unlike `train_step`, the forward pass runs with `training=False` and no optimizer update is performed.

        Args:
            data:
                One batch as handed over by Keras. It is unpacked into `(x, y, sample_weight)` with
                `data_adapter.unpack_x_y_sample_weight`.

        Returns:
            `dict`: The current result of every metric in `self.metrics`, keyed by metric name. Metrics whose result
            is itself a dict are merged into the returned dict directly.

        Raises:
            ValueError: If no separate labels were passed, no label column is found in the inputs, and the dummy loss
            is not in use.
        """
        # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
        arg_names = list(dict(inspect.signature(self.call).parameters).keys())
        label_kwargs = find_labels(self.__class__)
        label_to_output = self.get_label_to_output_name_mapping()
        output_to_label = {val: key for key, val in label_to_output.items()}
        if not self._using_dummy_loss:
            data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
        if isinstance(x, dict):
            x = x.copy()
        if isinstance(y, dict):
            y = y.copy()

        # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
        # if those keys are not already present in the input dict
        if self._using_dummy_loss and y is not None:
            # If y is a tensor and the model only has one label-like input, map y to that input
            if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
                if isinstance(x, tf.Tensor):
                    # A bare input tensor is bound to the first argument of `call`.
                    x = {arg_names[0]: x}
                label_kwarg = next(iter(label_kwargs))
                if label_kwarg not in x:
                    x[label_kwarg] = y
            # Otherwise, copy keys from y to x as long as they weren't already present in x
            elif isinstance(y, dict):
                if isinstance(x, tf.Tensor):
                    # A bare input tensor is bound to the first argument of `call`.
                    x = {arg_names[0]: x}
                for key, val in y.items():
                    if key in arg_names and key not in x:
                        x[key] = val
                    # The label may be keyed by output-head name; translate it back to the `call` argument name.
                    elif output_to_label.get(key, None) in arg_names and key not in x:
                        x[output_to_label[key]] = val
        if y is None:
            # No separate labels were given: pull the label columns out of the inputs.
            y = {key: val for key, val in x.items() if key in label_kwargs}
            if not y and not self._using_dummy_loss:
                raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")

        if isinstance(y, dict):
            # Rename labels at this point to match output heads
            y = {label_to_output.get(key, key): val for key, val in y.items()}

        # Run forward pass.
        y_pred = self(x, training=False)
        if self._using_dummy_loss:
            # The model's own loss output is passed as both target and prediction.
            loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
        else:
            # Deferred until labels and outputs have been matched up below.
            loss = None

        # This next block matches outputs to label keys. Tensorflow's standard method for doing this
        # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
        if isinstance(y, dict) and len(y) == 1:
            if list(y.keys())[0] in y_pred.keys():
                y_pred = y_pred[list(y.keys())[0]]
            elif list(y_pred.keys())[0] == "loss":
                # Skip the leading `loss` entry and use the next output.
                y_pred = y_pred[1]
            else:
                y_pred = y_pred[0]
            # Unwrap the single label so `y` is the bare value rather than a one-entry dict.
            _, y = y.popitem()
        elif isinstance(y, dict):
            # If the labels are a dict, match keys from the output by name
            y_pred = {key: val for key, val in y_pred.items() if key in y}
        elif isinstance(y, tuple) or isinstance(y, list):
            # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
            if list(y_pred.keys())[0] == "loss":
                y_pred = y_pred.to_tuple()[1:]
            else:
                y_pred = y_pred.to_tuple()
            y_pred = y_pred[: len(y)]  # Remove unused fields in case those cause problems
        else:
            # If the labels are a single tensor, match them to the first non-loss tensor in the output
            if list(y_pred.keys())[0] == "loss":
                y_pred = y_pred[1]
            else:
                y_pred = y_pred[0]

        if loss is None:
            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return return_metrics

Matt's avatar
Matt committed
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
    def create_model_card(
        self,
        output_dir,
        model_name: str,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Optional[str] = None,
        dataset_tags: Optional[Union[str, List[str]]] = None,
        dataset: Optional[Union[str, List[str]]] = None,
        dataset_args: Optional[Union[str, List[str]]] = None,
    ):
        """
        Builds a model card from this model's Keras training history and writes it to `README.md` inside
        `output_dir`.

        The card is produced by `TrainingSummary.from_keras(...).to_model_card()`; every argument other than
        `output_dir` is forwarded to `TrainingSummary.from_keras` unchanged.

        Args:
            output_dir (`str` or `os.PathLike`):
                Directory in which `README.md` is written. It is not created here, so it must already exist.
            model_name (`str`):
                Name of the model, forwarded to the training summary.
            language, license, tags, finetuned_from, tasks, dataset_tags, dataset, dataset_args (*optional*):
                Metadata forwarded as-is to `TrainingSummary.from_keras`.
        """
        # Avoids a circular import by doing this when necessary.
        from .modelcard import TrainingSummary  # tests_ignore

        training_summary = TrainingSummary.from_keras(
            self,
            keras_history=self.history,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        # Write explicitly as UTF-8: the platform default encoding may be unable to represent non-ASCII
        # characters in the card and would otherwise make the output depend on the machine's locale.
        with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

1589
1590
    def set_input_embeddings(self, value):
        """
        Set model's input embeddings

        The call is delegated to the main layer found under the attribute named by `self.base_model_prefix`. If that
        layer raises `AttributeError`, the model is run once on `self.dummy_inputs` and the call is retried.

        Args:
            value (`tf.Variable`):
                The new input embedding weights, forwarded to the main layer's `set_input_embeddings`.

        Raises:
            NotImplementedError: If the model has no attribute named `self.base_model_prefix` (or it is `None`).
        """
        # Default to None so a missing attribute reaches the explicit check below instead of escaping as a bare
        # AttributeError from `getattr`.
        main_layer = getattr(self, self.base_model_prefix, None)

        if main_layer is None:
            raise NotImplementedError("The model does not implement the base_model_prefix attribute.")

        try:
            main_layer.set_input_embeddings(value)
        except AttributeError:
            # Presumably the layer's weights do not exist until the model has been called once; build and retry.
            logger.info("Building the model")
            self(self.dummy_inputs)
            main_layer.set_input_embeddings(value)

    def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
        """
        Returns the model's output embeddings

        The result comes from the LM head's `get_output_embeddings`. If the head raises `AttributeError`, the model is
        run once on `self.dummy_inputs` and the lookup is retried.

        Returns:
            Whatever the LM head's `get_output_embeddings` returns, or `None` when `get_lm_head()` returns `None`.
        """
        lm_head = self.get_lm_head()
        if lm_head is None:
            return None  # Overwrite for models with output embeddings

        try:
            return lm_head.get_output_embeddings()
        except AttributeError:
            # Presumably the head's weights do not exist until the model has been called once; build and retry.
            logger.info("Building the model")
            self(self.dummy_inputs)

            # Query the head itself; calling the layer (`lm_head()`) would invoke its forward pass with no inputs.
            return lm_head.get_output_embeddings()

1629
1630
1631
1632
1633
    def set_output_embeddings(self, value):
        """
        Set model's output embeddings

        Does nothing when `get_lm_head()` returns `None`. Otherwise the value is passed to the LM head; if the head
        raises `AttributeError`, the model is run once on `self.dummy_inputs` and the call is retried.

        Args:
            value (`tf.Variable`):
                The new weights mapping hidden states to vocabulary.
        """
        if self.get_lm_head() is None:
            return

        head = self.get_lm_head()
        try:
            head.set_output_embeddings(value)
        except AttributeError:
            logger.info("Building the model")
            self(self.dummy_inputs)
            head.set_output_embeddings(value)

1646
1647
1648
    def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
        """
        Deprecated alias of `get_lm_head`: emits a `FutureWarning` and returns whatever `get_lm_head()` returns.

        Return:
            `tf.keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
        """
        deprecation_message = "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead."
        warnings.warn(deprecation_message, FutureWarning)

        lm_head = self.get_lm_head()
        return lm_head
1658
1659
1660

    def get_prefix_bias_name(self) -> Union[None, str]:
        """
        Deprecated: emits a `FutureWarning` pointing to `get_bias` and always returns `None`.

        Return:
            `None`: This base implementation does not compute a prefix name.
        """
        deprecation_message = "The method get_prefix_bias_name is deprecated. Please use `get_bias` instead."
        warnings.warn(deprecation_message, FutureWarning)
        return None

    def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
        """
        Dict of bias attached to an LM head. The key represents the name of the bias attribute.

        The dict is obtained from the LM head's `get_bias`. If the head raises `AttributeError`, the model is run once
        on `self.dummy_inputs` and the lookup is retried.

        Return:
            `Dict[str, tf.Variable]`: The weights representing the bias, None if not an LM model.
        """
        if self.get_lm_head() is None:
            return None

        head = self.get_lm_head()
        try:
            return head.get_bias()
        except AttributeError:
            self(self.dummy_inputs)
            return head.get_bias()

    def set_bias(self, value):
        """
        Set all the bias in the LM head.

        Does nothing when `get_lm_head()` returns `None`. If the head raises `AttributeError`, the model is run once
        on `self.dummy_inputs` and the call is retried.

        Args:
            value (`Dict[tf.Variable]`):
                All the new bias attached to an LM head.
        """
        if self.get_lm_head() is None:
            return

        head = self.get_lm_head()
        try:
            head.set_bias(value)
        except AttributeError:
            self(self.dummy_inputs)
            head.set_bias(value)

    def get_lm_head(self) -> tf.keras.layers.Layer:
        """
        Accessor for the LM head layer. This base implementation has no head and returns `None`; every model that has
        an LM head must override it.

        Return:
            `tf.keras.layers.Layer`: The LM head layer if the model has one, None if not.
        """
        return None

1711
1712
    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
        """
        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`, and records the
        new size in `self.config.vocab_size`.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`
                (or equal to the current vocabulary size), nothing is resized and the current input embedding weights
                are returned.

        Return:
            `tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
        """
        if new_num_tokens is not None and new_num_tokens != self.config.vocab_size:
            resized_embeddings = self._resize_token_embeddings(new_num_tokens)

            # Update base model and current model config
            self.config.vocab_size = new_num_tokens
            return resized_embeddings

        # Nothing to resize: hand back the existing input embedding weights.
        return self._get_word_embedding_weight(self.get_input_embeddings())

1736
    def _get_word_embedding_weight(model, embedding_layer):
        """
        Resolves the weights behind `embedding_layer`.

        A `tf.Tensor` is returned as is. Otherwise the layer's `weight` attribute is tried, then its `decoder`
        attribute. If neither is set, the model is run once on `model.dummy_inputs` and both attributes are tried
        again; `None` is returned if they are still missing.
        """
        # If the variable holds the weights themselves, return them
        if isinstance(embedding_layer, tf.Tensor):
            return embedding_layer

        def _lookup_weights():
            # Otherwise, try to get them from the layer's attributes
            for attr_name in ("weight", "decoder"):
                candidate = getattr(embedding_layer, attr_name, None)
                if candidate is not None:
                    return candidate
            return None

        found = _lookup_weights()
        if found is not None:
            return found

        # The reason why the attributes don't exist might be
        # because the model is not built, so retry getting
        # the argument after building the model
        model(model.dummy_inputs)
        return _lookup_weights()
1764

1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)

        # if word embeddings are not tied, make sure that lm head bias is resized as well
        if self.get_bias() is not None:
            old_lm_head_bias = self.get_bias()
            new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)

            self.set_bias(new_lm_head_bias)

        # if word embeddings are not tied, make sure that lm head decoder is resized as well
        if self.get_output_embeddings() is not None:
            old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
            new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)

            self.set_output_embeddings(new_lm_head_decoder)

        self.set_input_embeddings(new_embeddings)

        return self.get_input_embeddings()

    def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
        """
        Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
        Reducing the size will remove vectors from the end

        Args:
            old_lm_head_bias (`Dict[str, tf.Variable]`):
                Old lm head biases to be resized, keyed by attribute name. Each one is either rank 1 or has exactly
                two dimensions with the token dimension last.
            new_num_tokens (`int`):
                New number of tokens along the token dimension. It is used in arithmetic below, so it must be an
                actual integer.

        Return:
            `Dict[str, tf.Variable]`: Newly created bias variables (via `self.add_weight`), under the same keys as
            `old_lm_head_bias`.
        """
        resized_biases = {}

        for attr, weight in old_lm_head_bias.items():
            if tf.rank(weight) == 1:
                leading_dim, old_num_tokens = None, shape_list(weight)[0]
            else:
                leading_dim, old_num_tokens = shape_list(weight)
            is_vector = leading_dim is None

            size_diff = new_num_tokens - old_num_tokens
            if is_vector:
                final_shape = [new_num_tokens]
            else:
                final_shape = [leading_dim, new_num_tokens]

            # initialize new bias
            if tf.math.greater(size_diff, 0):
                # Growing: pad the old values out to the new size and mark only the old positions as valid.
                if is_vector:
                    padding_shape = [[0, size_diff]]
                else:
                    padding_shape = [[0, 0], [0, size_diff]]
                paddings = tf.convert_to_tensor(padding_shape)
                current_bias = tf.pad(weight.value(), paddings, constant_values=-1)

                num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
                if is_vector:
                    mask_shape = [num_tokens_to_copy]
                else:
                    mask_shape = [1, num_tokens_to_copy]
                bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True)
                bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False)
            else:
                # Same size or shrinking: keep the leading entries and treat all of them as valid.
                slice_from = [0] if is_vector else [0, 0]
                current_bias = tf.slice(
                    weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape)
                )
                bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True)

            new_bias = self.add_weight(
                shape=final_shape,
                initializer="zeros",
                trainable=True,
                name=weight.name.split(":")[0],
            )
            # Where the mask is False (the padded tail) the freshly initialised values of `new_bias` are kept.
            new_bias.assign(tf.where(bias_mask, current_bias, new_bias.value()))
            resized_biases[attr] = new_bias

        return resized_biases
thomwolf's avatar
thomwolf committed
1838

1839
1840
1841
1842
    def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens):
        """
        Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end.
        Reducing the size will remove vectors from the end

        Args:
            old_lm_head_decoder (`tf.Variable`):
                Old lm head decoder to be resized. May be `None`, in which case `None` is returned.
            new_num_tokens (`int`):
                New number of tokens in the linear matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end.

        Return:
            `tf.Variable`: Pointer to the resized decoder. The old decoder is returned untouched when it is `None` or
            when it is detected as being the same as the input embeddings.
        """
        # Nothing to resize. Checking this first also avoids comparing a tensor against `None` below.
        if old_lm_head_decoder is None:
            return None

        # The comparison result is reduced with `tf.reduce_any`, so a single matching position is enough for the
        # decoder to be treated as tied to the input embeddings.
        # NOTE(review): `tf.reduce_all` looks closer to the intent -- confirm before changing.
        is_input_output_equals = tf.reduce_any(
            self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
        )
        if is_input_output_equals:
            # Tied weights: the decoder is left as-is here.
            return old_lm_head_decoder

        old_embedding_dim = shape_list(old_lm_head_decoder)[1]
        # presumably `decoder_mask` flags the entries carried over from the old decoder and `current_decoder` is the
        # old decoder padded/sliced to the new size -- verify against `init_copy_embeddings`.
        decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)
        new_lm_head_decoder = self.add_weight(
            shape=(new_num_tokens, old_embedding_dim),
            initializer="zeros",
            trainable=True,
            name=old_lm_head_decoder.name.split(":")[0],
        )
        # Where the mask is set take the old values, elsewhere keep the freshly initialized (zero) values.
        init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())

        new_lm_head_decoder.assign(init_decoder)

        return new_lm_head_decoder
1876

1877
1878
1879
1880
    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end

        Args:
            old_embeddings (`tf.Variable`):
                Old embeddings to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
                `tf.Variable` module of the model without doing anything.

        Return:
            `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
            `None`
        """
        # Honor the documented contract: without a target size there is nothing to resize.
        if new_num_tokens is None:
            return old_embeddings

        old_embedding_dim = shape_list(old_embeddings)[1]
        # Fall back to 0.02 when the config does not define `initializer_range`.
        init_range = getattr(self.config, "initializer_range", 0.02)
        # presumably `embeddings_mask` flags the entries carried over from the old matrix and `current_embeddings` is
        # the old matrix padded/sliced to the new size -- verify against `init_copy_embeddings`.
        embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
        new_embeddings = self.add_weight(
            name=old_embeddings.name.split(":")[0],
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        # Where the mask is set take the old values, elsewhere keep the freshly initialized values.
        init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value())

        new_embeddings.assign(init_embeddings)

        return new_embeddings
thomwolf's avatar
thomwolf committed
1910
1911

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
                to prune in said layer (list of `int`). For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2
                on layer 1 and heads 2 and 3 on layer 2.

        Raises:
            NotImplementedError: Always; this implementation performs no pruning.
        """
        raise NotImplementedError

Arthur's avatar
Arthur committed
1923
1924
1925
1926
1927
1928
1929
1930
1931
    def save_pretrained(
        self,
        save_directory,
        saved_model=False,
        version=1,
        push_to_hub=False,
        max_shard_size: Union[int, str] = "10GB",
        **kwargs
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~TFPreTrainedModel.from_pretrained`] class method.

        Arguments:
            save_directory (`str`):
                Directory to which to save. Will be created if it doesn't exist. If the path points to an existing
                file, an error is logged and nothing is saved.
            saved_model (`bool`, *optional*, defaults to `False`):
                If the model has to be saved in saved model format as well or not.
            version (`int`, *optional*, defaults to 1):
                The version of the saved model. A saved model needs to be versioned in order to be properly loaded by
                TensorFlow Serving as detailed in the official documentation
                https://www.tensorflow.org/tfx/serving/serving_basic
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                <Tip warning={true}>

                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
                folder. Pass along `temp_dir=True` to use a temporary directory instead.

                </Tip>

            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
                than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).

                <Tip warning={true}>

                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
                which will be bigger than `max_shard_size`.

                </Tip>

            kwargs:
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if push_to_hub:
            # `commit_message` is consumed here; the remaining kwargs configure the repository.
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        if saved_model:
            saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
            self.save(saved_model_dir, include_optimizer=False, signatures=self.serving)
            logger.info(f"Saved model created in {saved_model_dir}")

        # Save configuration file. The first two characters of the class name are dropped
        # (presumably the "TF" prefix -- confirm against the class naming convention).
        self.config.architectures = [self.__class__.__name__[2:]]

        # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
        # loaded from the Hub.
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self.config)

        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)

        shards, index = tf_shard_checkpoint(self.weights, max_shard_size)

        # Clean the folder from a previous save: any weight file that is not going to be rewritten by this call is
        # removed. The prefix is the weights file name without its extension; `splitext` is used instead of a fixed
        # `[:-4]` slice, which is only correct for a 4-character suffix.
        weights_prefix = os.path.splitext(TF2_WEIGHTS_NAME)[0]
        for filename in os.listdir(save_directory):
            full_filename = os.path.join(save_directory, filename)
            if filename.startswith(weights_prefix) and os.path.isfile(full_filename) and filename not in shards:
                os.remove(full_filename)

        if index is None:
            # No index was produced: everything goes into one weights file.
            self.save_weights(output_model_file)
            logger.info(f"Model weights saved in {output_model_file}")
        else:
            save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as index_file:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                index_file.write(content)
            logger.info(
                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
                f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
                f"index located at {save_index_file}."
            )
            for shard_file, shard in shards.items():
                # Use a distinct name for the open handle so the loop's file name is not shadowed.
                with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_h5:
                    # Weight names are stored without their first "/"-separated component.
                    save_attributes_to_hdf5_group(
                        shard_h5,
                        "layer_names",
                        ["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
                    )

                    for layer in sorted(shard, key=lambda x: x.name):
                        # Materialize the weight once instead of once per use.
                        layer_value = layer.numpy()
                        param_dset = shard_h5.create_dataset(
                            "/".join(layer.name.split("/")[1:]), layer_value.shape, dtype=layer_value.dtype
                        )
                        param_dset[:] = layer_value

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")

thomwolf's avatar
thomwolf committed
2044
2045
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
2046
2047
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
thomwolf's avatar
thomwolf committed
2048

2049
        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
2050
2051
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.
thomwolf's avatar
thomwolf committed
2052

2053
        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
2054
        weights are discarded.
thomwolf's avatar
thomwolf committed
2055
2056

        Parameters:
2057
            pretrained_model_name_or_path (`str`, *optional*):
2058
2059
                Can be either:

2060
                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Sylvain Gugger's avatar
Sylvain Gugger committed
2061
2062
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
2063
2064
                    - A path to a *directory* containing model weights saved using
                      [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
Sylvain Gugger's avatar
Sylvain Gugger committed
2065
2066
2067
2068
                    - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
                      case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
                      argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
                      using the provided conversion scripts and loading the TensorFlow model afterwards.
2069
2070
2071
2072
2073
                    - `None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments `config` and `state_dict`).
            model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            config (`Union[PretrainedConfig, str]`, *optional*):
2074
2075
                Can be either:

2076
2077
                    - an instance of a class derived from [`PretrainedConfig`],
                    - a string valid as input to [`~PretrainedConfig.from_pretrained`].
2078

2079
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
2080
2081
                be automatically loaded when:

2082
                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
2083
                      model).
Sylvain Gugger's avatar
Sylvain Gugger committed
2084
2085
                    - The model was saved using [`~TFPreTrainedModel.save_pretrained`] and is reloaded by supplying the
                      save directory.
2086
2087
2088
                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                      configuration JSON file named *config.json* is found in the directory.
            from_pt: (`bool`, *optional*, defaults to `False`):
2089
                Load the model weights from a PyTorch state_dict save file (see docstring of
2090
2091
                `pretrained_model_name_or_path` argument).
            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
2092
2093
2094
                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
                checkpoint with 3 labels).
2095
            cache_dir (`str`, *optional*):
2096
2097
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
2098
            force_download (`bool`, *optional*, defaults to `False`):
2099
2100
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
2101
            resume_download (`bool`, *optional*, defaults to `False`):
2102
2103
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
Sylvain Gugger's avatar
Sylvain Gugger committed
2104
2105
2106
2107
2108
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
                messages.
2109
            local_files_only(`bool`, *optional*, defaults to `False`):
2110
                Whether or not to only look at local files (e.g., not try downloading the model).
2111
            use_auth_token (`str` or *bool*, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2112
2113
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `transformers-cli login` (stored in `~/.huggingface`).
2114
            revision (`str`, *optional*, defaults to `"main"`):
Julien Chaumond's avatar
Julien Chaumond committed
2115
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
2116
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
Julien Chaumond's avatar
Julien Chaumond committed
2117
                identifier allowed by git.
2118
            mirror (`str`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
2119
2120
2121
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
2122
            kwargs (remaining dictionary of keyword arguments, *optional*):
2123
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
2124
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
2125
2126
                automatically loaded:

2127
2128
                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
2129
                      already been done)
2130
                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
Sylvain Gugger's avatar
Sylvain Gugger committed
2131
2132
2133
2134
                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                      corresponds to a configuration attribute will be used to override said attribute with the
                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                      will be passed to the underlying model's `__init__` function.
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145

        <Tip>

        Passing `use_auth_token=True` is required when you want to use a private model.

        </Tip>

        Examples:

        ```python
        >>> from transformers import BertConfig, TFBertModel
Sylvain Gugger's avatar
Sylvain Gugger committed
2146

2147
        >>> # Download model and configuration from huggingface.co and cache.
Sylvain Gugger's avatar
Sylvain Gugger committed
2148
        >>> model = TFBertModel.from_pretrained("bert-base-uncased")
2149
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
Sylvain Gugger's avatar
Sylvain Gugger committed
2150
        >>> model = TFBertModel.from_pretrained("./test/saved_model/")
2151
        >>> # Update configuration during loading.
Sylvain Gugger's avatar
Sylvain Gugger committed
2152
        >>> model = TFBertModel.from_pretrained("bert-base-uncased", output_attentions=True)
2153
2154
        >>> assert model.config.output_attentions == True
        >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
Sylvain Gugger's avatar
Sylvain Gugger committed
2155
2156
        >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json")
        >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config)
2157
        ```"""
2158
2159
2160
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
2161
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
2162
2163
2164
2165
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
2166
        local_files_only = kwargs.pop("local_files_only", False)
2167
        use_auth_token = kwargs.pop("use_auth_token", None)
Julien Chaumond's avatar
Julien Chaumond committed
2168
        revision = kwargs.pop("revision", None)
2169
        mirror = kwargs.pop("mirror", None)
Ratthachat (Jung)'s avatar
Ratthachat (Jung) committed
2170
        load_weight_prefix = kwargs.pop("load_weight_prefix", None)
2171
2172
2173
2174
2175
2176
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline
thomwolf's avatar
thomwolf committed
2177

2178
2179
2180
2181
        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

2182
2183
2184
        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
thomwolf's avatar
thomwolf committed
2185
            config, model_kwargs = cls.config_class.from_pretrained(
2186
2187
2188
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
thomwolf's avatar
thomwolf committed
2189
                force_download=force_download,
2190
                resume_download=resume_download,
2191
2192
                proxies=proxies,
                local_files_only=local_files_only,
2193
                use_auth_token=use_auth_token,
Julien Chaumond's avatar
Julien Chaumond committed
2194
                revision=revision,
2195
2196
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
2197
                **kwargs,
thomwolf's avatar
thomwolf committed
2198
2199
2200
2201
            )
        else:
            model_kwargs = kwargs

Arthur's avatar
Arthur committed
2202
2203
2204
2205
        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
        # index of the files.
        is_sharded = False
        sharded_metadata = None
thomwolf's avatar
thomwolf committed
2206
        # Load model
thomwolf's avatar
thomwolf committed
2207
        if pretrained_model_name_or_path is not None:
2208
            if os.path.isdir(pretrained_model_name_or_path):
2209
2210
2211
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint in priority if from_pt
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
2212
2213
2214
2215
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
                    # Load from a sharded PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
                    is_sharded = True
2216
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
thomwolf's avatar
thomwolf committed
2217
2218
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
Arthur's avatar
Arthur committed
2219
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
2220
                    # Load from a sharded TF 2.0 checkpoint
Arthur's avatar
Arthur committed
2221
2222
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
                    is_sharded = True
2223
2224
2225
2226
2227
2228
2229
                # At this stage we don't have a weight file so we will raise an error.
                elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
                    raise EnvironmentError(
                        f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
                        "weights."
                    )
thomwolf's avatar
thomwolf committed
2230
                else:
2231
                    raise EnvironmentError(
2232
2233
                        f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                        f"{pretrained_model_name_or_path}."
2234
                    )
Julien Chaumond's avatar
Julien Chaumond committed
2235
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
thomwolf's avatar
thomwolf committed
2236
                archive_file = pretrained_model_name_or_path
2237
2238
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
thomwolf's avatar
thomwolf committed
2239
            else:
2240
                filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
thomwolf's avatar
thomwolf committed
2241
                archive_file = hf_bucket_url(
Julien Chaumond's avatar
Julien Chaumond committed
2242
                    pretrained_model_name_or_path,
2243
                    filename=filename,
Julien Chaumond's avatar
Julien Chaumond committed
2244
                    revision=revision,
2245
                    mirror=mirror,
thomwolf's avatar
thomwolf committed
2246
                )
thomwolf's avatar
thomwolf committed
2247
2248

            try:
2249
                # Load from URL or cache if already cached
2250
2251
2252
2253
2254
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
2255
2256
                    resume_download=resume_download,
                    local_files_only=local_files_only,
2257
                    use_auth_token=use_auth_token,
2258
                    user_agent=user_agent,
2259
                )
2260

2261
            except RepositoryNotFoundError:
2262
2263
2264
2265
2266
2267
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                    "login` and pass `use_auth_token=True`."
                )
2268
            except RevisionNotFoundError:
2269
2270
2271
2272
2273
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                    "this model name. Check the model page at "
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
2274
            except EntryNotFoundError:
2275
                if filename == TF2_WEIGHTS_NAME:
Arthur's avatar
Arthur committed
2276
2277
2278
2279
2280
2281
2282
                    try:
                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                        archive_file = hf_bucket_url(
                            pretrained_model_name_or_path,
                            filename=TF2_WEIGHTS_INDEX_NAME,
                            revision=revision,
                            mirror=mirror,
2283
                        )
Arthur's avatar
Arthur committed
2284
2285
2286
2287
2288
2289
2290
2291
2292
                        resolved_archive_file = cached_path(
                            archive_file,
                            cache_dir=cache_dir,
                            force_download=force_download,
                            proxies=proxies,
                            resume_download=resume_download,
                            local_files_only=local_files_only,
                            use_auth_token=use_auth_token,
                            user_agent=user_agent,
2293
                        )
Arthur's avatar
Arthur committed
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
                        is_sharded = True
                    except EntryNotFoundError:
                        # Otherwise, maybe there is a TF or Flax model file.  We try those to give a helpful error
                        # message.
                        has_file_kwargs = {
                            "revision": revision,
                            "mirror": mirror,
                            "proxies": proxies,
                            "use_auth_token": use_auth_token,
                        }
                        if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                                " load this model from those weights."
                            )
                        else:
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                            )
2315
2316
2317
2318
                else:
                    raise EnvironmentError(
                        f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
                    )
2319
            except HTTPError as err:
2320
                raise EnvironmentError(
2321
2322
2323
2324
2325
                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
                    f"{err}"
                )
            except ValueError:
                raise EnvironmentError(
Sylvain Gugger's avatar
Sylvain Gugger committed
2326
2327
2328
2329
2330
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your internet"
                    " connection or see how to run the library in offline mode at"
                    " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
2331
                )
2332
            except EnvironmentError:
2333
2334
2335
2336
2337
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
2338
                )
2339

thomwolf's avatar
thomwolf committed
2340
            if resolved_archive_file == archive_file:
2341
                logger.info(f"loading weights file {archive_file}")
thomwolf's avatar
thomwolf committed
2342
            else:
2343
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
thomwolf's avatar
thomwolf committed
2344
        else:
thomwolf's avatar
thomwolf committed
2345
            resolved_archive_file = None
thomwolf's avatar
thomwolf committed
2346

Arthur's avatar
Arthur committed
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
        if is_sharded:
            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                resolved_archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                revision=revision,
                mirror=mirror,
            )

2364
2365
        config.name_or_path = pretrained_model_name_or_path

Ratthachat (Jung)'s avatar
Ratthachat (Jung) committed
2366
2367
2368
2369
2370
        # composed models, *e.g.* TFRag, require special treatment when it comes to loading
        # pre-trained weights.
        if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
            model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")

thomwolf's avatar
thomwolf committed
2371
2372
2373
2374
        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
Julien Plu's avatar
Julien Plu committed
2375
2376
            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

thomwolf's avatar
thomwolf committed
2377
            # Load from a PyTorch checkpoint
Yih-Dar's avatar
Yih-Dar committed
2378
2379
2380
            return load_pytorch_checkpoint_in_tf2_model(
                model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
            )
thomwolf's avatar
thomwolf committed
2381

Ratthachat (Jung)'s avatar
Ratthachat (Jung) committed
2382
2383
2384
2385
2386
2387
        # we might need to extend the variable scope for composite models
        if load_weight_prefix is not None:
            with tf.compat.v1.variable_scope(load_weight_prefix):
                model(model.dummy_inputs)  # build the network with dummy inputs
        else:
            model(model.dummy_inputs)  # build the network with dummy inputs
thomwolf's avatar
thomwolf committed
2388
2389
2390

        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
2391
        try:
Arthur's avatar
Arthur committed
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
            if is_sharded:
                for file in resolved_archive_file:
                    os.path.isfile(file), f"Error retrieving files {file}"

                missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
                    model,
                    resolved_archive_file,
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                )
            else:
                missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
                    model,
                    resolved_archive_file,
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                    _prefix=load_weight_prefix,
                )
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
        except OSError as e:
            try:
                with open(resolved_archive_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please install "
                            "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                            "you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise OSError(
                    "Unable to load weights from h5 file. "
                    "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
                )
thomwolf's avatar
thomwolf committed
2424

Julien Plu's avatar
Julien Plu committed
2425
        model(model.dummy_inputs)  # Make sure restore ops are run
thomwolf's avatar
thomwolf committed
2426

2427
2428
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
2429
2430
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

2431
2432
        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
Julien Plu's avatar
Julien Plu committed
2433
2434
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

2435
2436
        if len(unexpected_keys) > 0:
            logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
2437
2438
2439
2440
2441
2442
2443
                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
2444
2445
            )
        else:
Julien Plu's avatar
Julien Plu committed
2446
2447
            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")

thomwolf's avatar
thomwolf committed
2448
        if len(missing_keys) > 0:
2449
            logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
2450
2451
2452
                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
2453
            )
2454
        elif len(mismatched_keys) == 0:
2455
            logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
2456
2457
2458
2459
                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
2460
            )
2461
2462
2463
2464
2465
2466
2467
2468
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
2469
2470
2471
2472
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
                " to use it for predictions and inference."
2473
            )
Julien Plu's avatar
Julien Plu committed
2474

thomwolf's avatar
thomwolf committed
2475
        if output_loading_info:
2476
2477
2478
2479
2480
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "mismatched_keys": mismatched_keys,
            }
Julien Plu's avatar
Julien Plu committed
2481

thomwolf's avatar
thomwolf committed
2482
2483
            return model, loading_info

thomwolf's avatar
thomwolf committed
2484
        return model
thomwolf's avatar
WIP  
thomwolf committed
2485

2486

2487
2488
2489
2490
2491
2492
2493
# To update the docstring, we need to copy the method, otherwise we change the original docstring.
TFPreTrainedModel.push_to_hub = copy_func(TFPreTrainedModel.push_to_hub)
# `__doc__` is None when the interpreter strips docstrings (`python -OO`); formatting it unconditionally would
# raise an AttributeError at import time, so only fill in the placeholders when a docstring is present.
if TFPreTrainedModel.push_to_hub.__doc__ is not None:
    TFPreTrainedModel.push_to_hub.__doc__ = TFPreTrainedModel.push_to_hub.__doc__.format(
        object="model", object_class="TFAutoModel", object_files="model checkpoint"
    )


thomwolf's avatar
WIP  
thomwolf committed
2494
class TFConv1D(tf.keras.layers.Layer):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    In practice this behaves like a linear (dense) layer whose weight matrix is stored transposed, i.e. with shape
    `[nx, nf]`.

    Args:
        nf (`int`):
            The number of output features.
        nx (`int`):
            The number of input features.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`.
    """

    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        # Only the sizes are stored here; the variables themselves are created in `build`.
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        """Create the `[nx, nf]` weight matrix and the `[1, nf]` zero-initialized bias."""
        weight_initializer = get_initializer(self.initializer_range)
        self.weight = self.add_weight("weight", shape=[self.nx, self.nf], initializer=weight_initializer)
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        """
        Project the last dimension of `x` from `nx` to `nf` features.

        The input is flattened to `[-1, nx]`, multiplied by the weight, shifted by the bias and reshaped to
        `[dim0, dim1, nf]`, where `dim0` and `dim1` are the first two dimensions of `x`.
        """
        batch_size, seq_length = shape_list(x)[:2]

        flattened = tf.reshape(x, [-1, self.nx])
        projected = tf.matmul(flattened, self.weight) + self.bias

        return tf.reshape(projected, [batch_size, seq_length, self.nf])
thomwolf's avatar
thomwolf committed
2532
2533


thomwolf's avatar
thomwolf committed
2534
class TFSharedEmbeddings(tf.keras.layers.Layer):
    r"""
    Construct shared token embeddings.

    The weights of the embedding layer are usually shared with the weights of the linear decoder when doing language
    modeling.

    Args:
        vocab_size (`int`):
            The size of the vocabulary, e.g., the number of unique tokens.
        hidden_size (`int`):
            The size of the embedding vectors.
        initializer_range (`float`, *optional*):
            The standard deviation to use when initializing the weights. If no value is provided, it will default to
            \\(1/\sqrt{hidden\_size}\\).
        kwargs:
            Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`.
    """

    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # Default standard deviation is hidden_size ** -0.5 when the caller does not provide one.
        if initializer_range is None:
            initializer_range = hidden_size**-0.5
        self.initializer_range = initializer_range

    def build(self, input_shape):
        """
        Build shared token embedding layer Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        embedding_initializer = get_initializer(self.initializer_range)
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=embedding_initializer
        )
        super().build(input_shape)

    def get_config(self):
        """Return the base layer config extended with this layer's constructor arguments."""
        own_config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        # Keys defined here take precedence over same-named keys of the base config.
        return {**super().get_config(), **own_config}

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
        """
        Get token embeddings of inputs or decode final hidden state.

        Args:
            inputs (`tf.Tensor`):
                In embedding mode, should be an int64 tensor with shape `[batch_size, length]`.

                In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`.
            mode (`str`, defaults to `"embedding"`):
               A valid value is either `"embedding"` or `"linear"`, the first one indicates that the layer should be
               used as an embedding layer, the second one that the layer should be used as a linear decoder.

        Returns:
            `tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape `[batch_size, length,
            embedding_size]`.

            In linear mode, the output is a float32 with shape `[batch_size, length, vocab_size]`.

        Raises:
            ValueError: if `mode` is not valid.

        Shared weights logic is adapted from
        [here](https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24).
        """
        if mode == "embedding":
            return self._embedding(inputs)
        if mode == "linear":
            return self._linear(inputs)
        raise ValueError(f"mode {mode} is not valid.")

    def _embedding(self, input_ids):
        """Look up the rows of the shared weight matrix selected by `input_ids`."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [..., hidden_size]

        Returns:
            float32 tensor with shape [..., vocab_size].
        """
        leading_dims = shape_list(inputs)[:-1]
        # Collapse all leading dimensions so one matmul against the transposed embedding matrix yields the logits.
        flattened = tf.reshape(inputs, [-1, self.hidden_size])
        flat_logits = tf.matmul(flattened, self.weight, transpose_b=True)

        return tf.reshape(flat_logits, leading_dims + [self.vocab_size])


thomwolf's avatar
thomwolf committed
2632
class TFSequenceSummary(tf.keras.layers.Layer):
    """
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

              Defaults to `"last"` when the config does not define it.
            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Name of the activation applied to the output (looked up with
              `get_tf_activation`); `None` adds no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.

        initializer_range (`float`, defaults to 0.02): The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`.
    """

    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
        super().__init__(**kwargs)

        # Read `summary_type` itself and fall back to "last" when it is absent. The previous check tested for
        # `summary_use_proj`, which raised an AttributeError for configs defining only `summary_use_proj` and
        # silently ignored `summary_type` for configs defining only `summary_type`.
        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            raise NotImplementedError

        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            # Project either to the number of labels or back to the hidden size.
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = False
        activation_string = getattr(config, "summary_activation", None)
        if activation_string is not None:
            self.has_activation = True
            self.activation = get_tf_activation(activation_string)

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, cls_index=None, training=False):
        """
        Summarize a sequence of hidden states into a single vector per sequence.

        Args:
            inputs:
                Either the hidden states tensor, a tuple/list `(hidden_states,)` or `(hidden_states, cls_index)`, or a
                dict with the key `"hidden_states"` and optionally `"cls_index"`. When a tuple/list or dict is given,
                the `cls_index` it carries (or `None`) replaces the `cls_index` argument.
            cls_index:
                Positions to gather when `summary_type == "cls_index"`. If `None`, the last position of the sequence
                axis (second-to-last axis of the hidden states) is used.
            training (`bool`, defaults to `False`):
                Forwarded to the dropout layers.

        Returns:
            The summary tensor, after the optional first dropout, projection, activation and last dropout.

        Raises:
            ValueError: if `summary_type` is not one of the supported values.
        """
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                # Default to the last position: a tensor shaped like the leading dims of `hidden_states`
                # (e.g. [batch] or [batch, num choices]) filled with `seq length - 1`.
                cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1)
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                # Add a trailing axis of size 1 so that exactly one position is gathered per sequence.
                cls_index = tf.expand_dims(cls_index, axis=-1)
            # Gather along the sequence axis, treating all leading dims as batch dims.
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError
        else:
            # Without this branch `output` would be undefined and the method would fail with an UnboundLocalError.
            raise ValueError(f"Unknown `summary_type`: {self.summary_type!r}")

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output

    # NOTE(review): this classmethod stores `_auto_class` on the class it is called on. Its docstring talks about
    # registering custom *models*, so it may have been intended for the model base class rather than this layer --
    # confirm before relying on it here.
    @classmethod
    def register_for_auto_class(cls, auto_class="TFAutoModel"):
        """
        Register this class with a given auto class. This should only be used for custom models as the ones in the
        library are already mapped with an auto class.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
                The auto class to register this new model with.

        Raises:
            ValueError: if `auto_class` is not the name of an attribute of `transformers.models.auto`.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class

2773

Sylvain Gugger's avatar
Sylvain Gugger committed
2774
2775
def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
    """
    Build the truncated-normal initializer used for model weights.

    Args:
        initializer_range (`float`, defaults to 0.02): Standard deviation passed to the initializer as `stddev`.

    Returns:
        `tf.initializers.TruncatedNormal`: The truncated normal initializer.
    """
    truncated_normal = tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
    return truncated_normal
2785
2786


Sam Shleifer's avatar
Sam Shleifer committed
2787
2788
class TFWrappedEmbeddings:
    """
    Wraps a shared-embeddings layer in a plain Python (non-Keras-layer) class to avoid problems with weight
    restoring. It also makes sure the layer is called from the correct scope to avoid problems with saving/storing
    the correct weights.

    Args:
        layer:
            The wrapped layer. It must be callable and expose a `call` method, both accepting `(inputs, mode)`.
        abs_scope_name (*optional*):
            Absolute variable scope under which the layer should be invoked. When `None`, the layer is invoked
            directly without entering any scope.
    """

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name

    def _invoke(self, target, inputs, mode):
        """Run `target(inputs, mode)`, re-entering the absolute scope first when one was given."""
        if self._abs_scope_name is not None:
            # if an abs scope name is given to the embedding variable, call variable from absolute scope
            with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope:
                with tf.name_scope(abs_scope.original_name_scope):
                    return target(inputs, mode)

        return target(inputs, mode)

    def call(self, inputs, mode="embedding"):
        """Forward to the wrapped layer's `call` method."""
        return self._invoke(self._layer.call, inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        """Forward to the wrapped layer's `__call__`."""
        return self._invoke(self._layer, inputs, mode)