# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""

import functools
import inspect
import os
import re
import warnings
from typing import Dict, List, Optional, Union

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import (
    DUMMY_INPUTS,
    TF2_WEIGHTS_NAME,
    WEIGHTS_NAME,
    ModelOutput,
    PushToHubMixin,
    cached_path,
    copy_func,
    hf_bucket_url,
    is_offline_mode,
    is_remote_url,
)
from .generation_tf_utils import TFGenerationMixin
from .tokenization_utils_base import BatchEncoding
from .utils import logging


logger = logging.get_logger(__name__)
tf_logger = tf.get_logger()

TFModelInputType = Union[
    List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
]


def dummy_loss(y_true, y_pred):
    # The model computes its loss internally; this passthrough "loss" simply averages it.
    return tf.reduce_mean(y_pred)


class TFModelUtilsMixin:
    """
    A few utilities for :obj:`tf.keras.Model`, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get the number of (optionally, trainable) parameters in the model.

        Args:
            only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to return only the number of trainable parameters.

        Returns:
            :obj:`int`: The number of parameters.
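
        Example (illustrative)::

            model.num_parameters(only_trainable=True)  # e.g. ~110M for a BERT-base sized model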
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()


def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:

    1. Adding a :obj:`config` dict to the Keras config dictionary in :obj:`get_config` (called by Keras at
       serialization time).
    2. Wrapping :obj:`__init__` to accept that :obj:`config` dict (passed by Keras at deserialization time) and
       convert it to a config object for the actual layer initializer.
    3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not
       need to be supplied in :obj:`custom_objects` in the call to :obj:`tf.keras.models.load_model`.

    Args:
        cls (a :obj:`tf.keras.layers.Layer` subclass):
            Typically a :obj:`TF.MainLayer` class in this project; it must accept a :obj:`config` argument in its
            initializer.

    Returns:
        The same class object, with modifications for Keras deserialization.
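
    Example, a minimal sketch (``MyConfig`` and ``TFMyMainLayer`` are hypothetical names; ``MyConfig`` is assumed to
    be a :class:`~transformers.PretrainedConfig` subclass)::

        @keras_serializable
        class TFMyMainLayer(tf.keras.layers.Layer):
            config_class = MyConfig  # required by @keras_serializable

            def __init__(self, config, **kwargs):
                super().__init__(**kwargs)
                self.config = config

        # The decorated layer can then be reloaded with tf.keras.models.load_model
        # without passing it in `custom_objects`.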
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)

        if isinstance(config, dict):
            config = config_class.from_dict(config)
            initializer(self, config, *args, **kwargs)
        elif isinstance(config, PretrainedConfig):
            if len(args) > 0:
                initializer(self, *args, **kwargs)
            else:
                initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)")

        self._config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["config"] = self._config.to_dict()
            cfg.update(self._kwargs)
            return cfg

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls


class TFCausalLanguageModelingLoss:
    """
    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
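
    A minimal sketch of the masking behaviour (shapes and values are illustrative)::

        labels = tf.constant([[17, -100, 42]])   # position 1 is ignored
        logits = tf.random.uniform((1, 3, 100))  # (batch_size, seq_len, vocab_size)
        per_token_loss = self.compute_loss(labels, logits)  # one value per unmasked position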

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100 affect the loss
        active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
        return loss_fn(labels, reduced_logits)


class TFQuestionAnsweringLoss:
    """
    Loss function suitable for question answering.
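
    Labels are expected as a dict with ``start_position`` and ``end_position`` keys (an illustrative sketch)::

        labels = {"start_position": tf.constant([3]), "end_position": tf.constant([7])}
        loss = self.compute_loss(labels, (start_logits, end_logits))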
    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        start_loss = loss_fn(labels["start_position"], logits[0])
        end_loss = loss_fn(labels["end_position"], logits[1])

        return (start_loss + end_loss) / 2.0


class TFTokenClassificationLoss:
    """
    Loss function suitable for token classification.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100
        # are taken into account in the loss computation
        if tf.math.reduce_any(labels == -1):
            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
            active_loss = tf.reshape(labels, (-1,)) != -1
        else:
            active_loss = tf.reshape(labels, (-1,)) != -100
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

        return loss_fn(labels, reduced_logits)


class TFSequenceClassificationLoss:
    """
    Loss function suitable for sequence classification.
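
    A sketch of how the loss is chosen from the logits shape (illustrative)::

        logits = tf.random.uniform((8, 1))  # one output dim -> MeanSquaredError (regression)
        logits = tf.random.uniform((8, 3))  # several dims -> SparseCategoricalCrossentropy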
    """

    def compute_loss(self, labels, logits):
        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
            loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        else:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
            )

        return loss_fn(labels, logits)


class TFMultipleChoiceLoss:
    """Loss function suitable for multiple choice tasks."""

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        return loss_fn(labels, logits)


class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.

    .. note::

         Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """


class TFNextSentencePredictionLoss:
    """
    Loss function suitable for next sentence prediction (NSP), that is, the task of predicting whether two sentences
    are consecutive.

    .. note::
         Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100
        # are taken into account in the loss computation
        next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
        next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)

        return loss_fn(next_sentence_label, next_sentence_reduced_logits)


def booleans_processing(config, **kwargs):
    """
    Process the boolean inputs of each model to make sure they are compliant with the execution mode (eager or
    graph).

    Args:
        config (:class:`~transformers.PretrainedConfig`):
            The config of the running model.
        **kwargs:
            The boolean parameters

    Returns:
        A dictionary with the proper values for each boolean
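
    Example (illustrative; in graph mode the config values take precedence)::

        flags = booleans_processing(
            config,
            output_attentions=None,  # falls back to config.output_attentions
            output_hidden_states=True,
            return_dict=None,
        )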
    """
    final_booleans = {}

    if tf.executing_eagerly():
        final_booleans["output_attentions"] = (
            kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
        )
        final_booleans["output_hidden_states"] = (
            kwargs["output_hidden_states"]
            if kwargs["output_hidden_states"] is not None
            else config.output_hidden_states
        )
        final_booleans["return_dict"] = (
            kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict
        )

        if "use_cache" in kwargs:
            final_booleans["use_cache"] = kwargs["use_cache"] if kwargs["use_cache"] is not None else config.use_cache
    else:
        if (
            kwargs["output_attentions"] not in (None, config.output_attentions)
            or kwargs["output_hidden_states"] not in (None, config.output_hidden_states)
            or ("use_cache" in kwargs and kwargs["use_cache"] not in (None, config.use_cache))
        ):
            tf_logger.warning(
                "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model."
                "They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`)."
            )

        final_booleans["output_attentions"] = config.output_attentions
        final_booleans["output_hidden_states"] = config.output_hidden_states

        if kwargs.get("return_dict", None) not in (None, True):
            tf_logger.warning(
                "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`."
            )
        final_booleans["return_dict"] = True

        if "use_cache" in kwargs:
            final_booleans["use_cache"] = config.use_cache

    return final_booleans


def input_processing(func, config, input_ids, **kwargs):
    """
    Process the input of each TensorFlow model, including the booleans. In case of a list of symbolic inputs, each
    input has to be named according to its parameter name, i.e. `input_ids = tf.keras.Input(shape=(128,),
    dtype='int32', name="input_ids")`, otherwise the order of the tensors is not guaranteed during training.

    Args:
        func (:obj:`callable`):
            The callable function of the TensorFlow model.
        config (:class:`~transformers.PretrainedConfig`):
            The config of the running model.
        **kwargs:
            The inputs of the model.

    Returns:
        A dictionary with the parsed inputs, mapping each parameter name to its value.
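
    Example of the typical call made from a model's :obj:`call` method (illustrative)::

        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            kwargs_call=kwargs,
        )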
    """
    signature = dict(inspect.signature(func).parameters)
    signature.pop("kwargs", None)
    signature.pop("self", None)
    parameter_names = list(signature.keys())
    output = {}
    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)

    if "inputs" in kwargs["kwargs_call"]:
        warnings.warn(
            "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
            FutureWarning,
        )

        output["input_ids"] = kwargs["kwargs_call"].pop("inputs")

    if "decoder_cached_states" in kwargs["kwargs_call"]:
        warnings.warn(
            "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
            FutureWarning,
        )
        output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")

    if len(kwargs["kwargs_call"]) > 0:
        raise ValueError(
            f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
        )

    kwargs.pop("kwargs_call")

    for k, v in kwargs.items():
        if isinstance(v, allowed_types) or v is None:
            output[k] = v
        else:
            raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")

    if isinstance(input_ids, (tuple, list)):
        for i, input in enumerate(input_ids):
            # EagerTensors don't allow to use the .name property so we check for a real Tensor
            if type(input) == tf.Tensor:
                # Tensor names always have the pattern `name:id`, so we check only the
                # `name` part
                tensor_name = input.name.split(":")[0]

                if tensor_name in parameter_names:
                    output[tensor_name] = input
                else:
                    output[parameter_names[i]] = input
            elif isinstance(input, allowed_types) or input is None:
                output[parameter_names[i]] = input
            else:
                raise ValueError(
                    f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
                )
    elif isinstance(input_ids, (dict, BatchEncoding)):
        if "inputs" in input_ids:
            warnings.warn(
                "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
                FutureWarning,
            )

            output["input_ids"] = input_ids.pop("inputs")

        if "decoder_cached_states" in input_ids:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            output["past_key_values"] = input_ids.pop("decoder_cached_states")

        for k, v in dict(input_ids).items():
            if isinstance(v, allowed_types) or v is None:
                output[k] = v
            elif k not in parameter_names and "args" not in parameter_names:
                logger.warning(
                    f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
                )
                continue
            else:
                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
    else:
        if isinstance(input_ids, tf.Tensor) or input_ids is None:
            output[parameter_names[0]] = input_ids
        else:
            raise ValueError(
                f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
            )

    for name in parameter_names:
        if name not in list(output.keys()) and name != "args":
            output[name] = kwargs.pop(name, signature[name].default)

    # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
    # So to respect the proper output we have to add this exception
    if "args" in output:
        if output["args"] is not None and type(output["args"]) == tf.Tensor:
            tensor_name = output["args"].name.split(":")[0]
            output[tensor_name] = output["args"]
        else:
            # `args` in this case is always the first parameter, then `input_ids`
            output["input_ids"] = output["args"]

        del output["args"]

    if "kwargs" in output:
        del output["kwargs"]

    boolean_dict = {
        k: v
        for k, v in output.items()
        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
    }

    output.update(
        booleans_processing(
            config=config,
            **boolean_dict,
        )
    )

    return output


def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """
    Detect missing and unexpected layers and load the TF weights according to their names and shapes.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.
        ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to ignore weights with shapes that don't match between the checkpoint and the model.

    Returns:
        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
        mismatched layers.
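
    Example (illustrative)::

        missing, unexpected, mismatched = load_tf_weights(model, "tf_model.h5")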
    """
    missing_layers = []
    unexpected_layers = []
    mismatched_layers = []

    # Read the H5 file
    with h5py.File(resolved_archive_file, "r") as f:
        # Retrieve the name of each layer from the H5 file
        saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))

        # Find the missing layers from the high level list of layers
        missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)

        # Find the unexpected layers from the high level list of layers
        unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers]))
        saved_weight_names_set = set()
        symbolic_weights_names = set()
        weight_value_tuples = []

        # Compute missing and unexpected sub layers
        # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
        for layer in model.layers:
            # if layer_name from the H5 file belongs to the layers from the instantiated model
            if layer.name in saved_h5_model_layers_name:
                # Get the H5 layer object from its name
                h5_layer_object = f[layer.name]
                # Get all the weights as a list from the layer object
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                saved_weights = {}

                # Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
                # And a set with only the names
                for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
                    # TF names always start with the model name so we ignore it
                    name = "/".join(weight_name.split("/")[1:])

                    if _prefix is not None:
                        name = _prefix + "/" + name

                    saved_weights[name] = np.asarray(h5_layer_object[weight_name])

                    # Add the updated name to the final list for computing missing/unexpected values
                    saved_weight_names_set.add(name)

                # Loop over each weight from the instantiated model and compare with the weights from the H5 file
                for symbolic_weight in symbolic_weights:
                    # TF names always start with the model name so we ignore it
                    if _prefix is not None:
                        delimiter = len(_prefix.split("/"))
                        symbolic_weight_name = "/".join(
                            symbolic_weight.name.split("/")[:delimiter]
                            + symbolic_weight.name.split("/")[delimiter + 1 :]
                        )
                    else:
                        symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])

                    # here we check if the current weight is among the weights from the H5 file
                    # If yes, get the weight_value of the corresponding weight from the H5 file
                    # If not, set the value to None
                    saved_weight_value = saved_weights.get(symbolic_weight_name, None)

                    # Add the updated name to the final list for computing missing/unexpected values
                    symbolic_weights_names.add(symbolic_weight_name)

                    # If the current weight is found
                    if saved_weight_value is not None:
                        # Check if the shape of the current weight and the one from the H5 file are different
                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            # If yes, we reshape the weight from the H5 file to match the shape of the current weight
                            # If the two shapes are not compatible, we raise an error
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except ValueError as e:
                                if ignore_mismatched_sizes:
                                    mismatched_layers.append(
                                        (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                                    )
                                    continue
                                else:
                                    raise e
                        else:
                            array = saved_weight_value

                        # We create the tuple that will be loaded and add it to the final list
                        weight_value_tuples.append((symbolic_weight, array))

    # Load all the weights
    K.batch_set_value(weight_value_tuples)

    # Compute the missing and unexpected layers
    missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
    unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers, mismatched_layers


def init_copy_embeddings(old_embeddings, new_num_tokens):
    r"""
    This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
    new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be
    kept or not. Example:

        - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]

            -  mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
        - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]

            - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]
    """
    old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
    size_diff = new_num_tokens - old_num_tokens

    # initialize new embeddings
    # Copy token embeddings from the previous ones
    if tf.math.greater(size_diff, 0):
        # if the new size is greater than the old one, we pad the current embeddings with -1 up to the new size,
        # and create a mask to identify the padded values so they can later be replaced by values from the newly
        # created embeddings
        current_weights = tf.pad(
            old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
        )
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
        mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
    else:
        # if the new size is lower than the old one, we take the current embeddings up to the new size
        current_weights = tf.slice(
            old_embeddings.value(),
            tf.convert_to_tensor([0, 0]),
            tf.convert_to_tensor([new_num_tokens, old_embedding_dim]),
        )
        mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)

    return mask, current_weights


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
    r"""
    Base class for all TF models.

    :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:

        * resize the input embeddings,
        * prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):

        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    base_model_prefix = ""
    # a list of re patterns of tensor names to ignore from the model when loading the model weights
    # (and avoid unnecessary warnings).
    _keys_to_ignore_on_load_missing = None
    # a list of re patterns of tensor names to ignore from the weights when loading the model weights
    # (and avoid unnecessary warnings).
    _keys_to_ignore_on_load_unexpected = None
    _requires_load_weight_prefix = False

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
        """
        return {
            "input_ids": tf.constant(DUMMY_INPUTS),
        }

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PretrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # Save config and origin of the pretrained weights if given in model
        self.config = config
        self.name_or_path = config.name_or_path

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
                "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
                "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
            }
        ]
    )
    def serving(self, inputs):
        """
        Method used for serving the model.

        Args:
            inputs (:obj:`Dict[str, tf.Tensor]`):
                The input of the saved model as a dictionary of tensors.
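
        Example of exporting the model with this serving signature (a sketch; the path is illustrative)::

            model.save("saved_model/1", include_optimizer=False, signatures=model.serving)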
        """
        output = self.call(inputs)

        return self.serving_output(output)

    def serving_output(self, output):
        """
        Prepare the output of the saved model. Each model must implement this function.

        Args:
            output (:obj:`~transformers.TFBaseModelOutput`):
                The output returned by the model.
        """
        raise NotImplementedError

    def get_input_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's input embeddings layer.

        Returns:
            :obj:`tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
        """
        main_layer = getattr(self, self.base_model_prefix, self)

        if main_layer is not self:
            return main_layer.get_input_embeddings()
        else:
            raise NotImplementedError

    def compile(
        self,
        optimizer="rmsprop",
        loss="passthrough",
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        run_eagerly=None,
        steps_per_execution=None,
        **kwargs
    ):
        """
        This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
        function themselves.
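
        For instance (illustrative), this allows::

            model.compile(optimizer="adam")  # no loss passed - the internal loss is used
            model.fit(train_dataset)         # dataset elements include a "labels" key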
        """
        if loss == "passthrough":
            logger.warning(
                "No loss specified in compile() - the model's internal loss computation will be used as the "
                "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
                "Please ensure your labels are passed as the 'labels' key of the input dict so that they are "
                "accessible to the model during the forward pass. To disable this behaviour, please pass a "
                "loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
            )
            loss = {"loss": dummy_loss}
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            **kwargs,
        )

    def train_step(self, data):
        """
        A modification of Keras's default train_step that cleans up the printed metrics when we use a dummy loss.
        """
        # These are the only transformations `Model.fit` applies to user-input
        # data when a `tf.data.Dataset` is provided.
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # These next two lines differ from the base method - they avoid issues when the labels are in
        # the input dict (and loss is computed internally)
        if y is None and "labels" in x:
            y = x["labels"]  # Stops confusion with metric computations
        # Run forward pass.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
        # Run backwards pass.
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        # These next two lines are also not in the base method - they correct the displayed metrics
        # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
        if "loss" in return_metrics and "loss_loss" in return_metrics:
            del return_metrics["loss_loss"]
        return return_metrics

    def test_step(self, data):
        """
        A modification of Keras's default test_step that cleans up the printed metrics when we use a dummy loss.
        """
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # These next two lines differ from the base method - they avoid issues when the labels are in
        # the input dict (and loss is computed internally)
        if y is None and "labels" in x:
            y = x["labels"]  # Stops confusion with metric computations
        y_pred = self(x, training=False)
        if not self.loss:
            self.loss_tracker.update_state(y_pred.loss)
            return_metrics = {"loss": self.loss_tracker.result()}
        else:
            # Run the compiled loss anyway so its internal state stays updated
            self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
            return_metrics = {}
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        # These next two lines are also not in the base method - they correct the displayed metrics
        # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
        if "loss" in return_metrics and "loss_loss" in return_metrics:
            del return_metrics["loss_loss"]
        return return_metrics

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings

        Args:
            value (:obj:`tf.Variable`):
                The new weights mapping hidden states to vocabulary.
        """
        main_layer = getattr(self, self.base_model_prefix, None)

        if main_layer is None:
            raise NotImplementedError("The model does not implements the base_model_prefix attribute.")

        try:
            main_layer.set_input_embeddings(value)
        except AttributeError:
            logger.info("Building the model")
            self(self.dummy_inputs)
            main_layer.set_input_embeddings(value)

    def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
        """
        Returns the model's output embeddings

        Returns:
            :obj:`tf.Variable`: The weights mapping hidden states to vocabulary.
        """
        if self.get_lm_head() is not None:
            lm_head = self.get_lm_head()

            try:
                return lm_head.get_output_embeddings()
            except AttributeError:
                logger.info("Building the model")
                self(self.dummy_inputs)

                return lm_head.get_output_embeddings()

        return None  # Overwrite for models with output embeddings

    def set_output_embeddings(self, value):
        """
        Set model's output embeddings

        Args:
            value (:obj:`tf.Variable`):
                The new weights mapping hidden states to vocabulary.
        """
        if self.get_lm_head() is not None:
            lm_head = self.get_lm_head()
            try:
                lm_head.set_output_embeddings(value)
            except AttributeError:
                logger.info("Building the model")
                self(self.dummy_inputs)
                lm_head.set_output_embeddings(value)

    def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
        """
        Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the
        embeddings

        Return:
            :obj:`tf.keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
        """
        warnings.warn(
            "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning
        )
        return self.get_lm_head()

    def get_prefix_bias_name(self) -> Union[None, str]:
        """
        Get the concatenated ``_prefix`` name of the bias, from the model name down to the parent layer.

        Return:
            :obj:`str`: The _prefix name of the bias.
        """
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return None

    def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
        """
        Dict of bias attached to an LM head. The key represents the name of the bias attribute.

        Return:
            :obj:`tf.Variable`: The weights representing the bias, None if not an LM model.
        """
        if self.get_lm_head() is not None:
            lm_head = self.get_lm_head()
            try:
                return lm_head.get_bias()
            except AttributeError:
                self(self.dummy_inputs)

                return lm_head.get_bias()
        return None

    def set_bias(self, value):
        """
        Set all the bias in the LM head.

        Args:
            value (:obj:`Dict[str, tf.Variable]`):
                All the new bias attached to an LM head.
        """
        if self.get_lm_head() is not None:
            lm_head = self.get_lm_head()
            try:
                lm_head.set_bias(value)
            except AttributeError:
                self(self.dummy_inputs)
                lm_head.set_bias(value)

    def get_lm_head(self) -> tf.keras.layers.Layer:
        """
        The LM Head layer. This method must be overwritten by all the models that have an LM head.

        Return:
            :obj:`tf.keras.layers.Layer`: The LM head layer if the model has one, None if not.
        """
        return None

    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.

        Arguments:
            new_num_tokens (:obj:`int`, `optional`):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
                just returns a pointer to the input tokens :obj:`tf.Variable` module of the model without doing
                anything.

        Return:
            :obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
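
        Example (illustrative)::

            # after adding new tokens to the tokenizer, grow the embedding matrix to match
            model.resize_token_embeddings(len(tokenizer))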
        """
        if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
            return self._get_word_embedding_weight(self.get_input_embeddings())

        model_embeds = self._resize_token_embeddings(new_num_tokens)

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens

        return model_embeds

    def _get_word_embedding_weight(model, embedding_layer):
        embeds = getattr(embedding_layer, "weight", None)
        if embeds is not None:
            return embeds

        embeds = getattr(embedding_layer, "decoder", None)
        if embeds is not None:
            return embeds

        # The reason why the attributes don't exist might be
        # because the model is not built, so retry getting
        # the argument after building the model
        model(model.dummy_inputs)

        embeds = getattr(embedding_layer, "weight", None)
        if embeds is not None:
            return embeds

        embeds = getattr(embedding_layer, "decoder", None)
        if embeds is not None:
            return embeds

        return None

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)

        # if word embeddings are not tied, make sure that lm head bias is resized as well
        if self.get_bias() is not None:
            old_lm_head_bias = self.get_bias()
            new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)

            self.set_bias(new_lm_head_bias)

        # if word embeddings are not tied, make sure that lm head decoder is resized as well
        if self.get_output_embeddings() is not None:
            old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
            new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)

            self.set_output_embeddings(new_lm_head_decoder)

        self.set_input_embeddings(new_embeddings)

        return self.get_input_embeddings()

    def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
        """
        Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
        Reducing the size will remove vectors from the end.

        Args:
            old_lm_head_bias (:obj:`tf.Variable`):
                Old lm head bias to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the linear matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns None

        Return:
            :obj:`tf.Variable`: Pointer to the resized bias.
        """
        new_lm_head_bias = {}

        for attr, weight in old_lm_head_bias.items():
            first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
            size_diff = new_num_tokens - old_num_tokens
            final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens]

            # initialize new bias
            if tf.math.greater(size_diff, 0):
                padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
                current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1)
                num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
                mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy]
                bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True)
                bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False)
            else:
                slice_from = [0] if first_dim is None else [0, 0]
                current_bias = tf.slice(
                    weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape)
                )
                bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True)

            new_bias = self.add_weight(
                shape=final_shape,
                initializer="zeros",
                trainable=True,
                name=weight.name.split(":")[0],
            )
            init_bias = tf.where(bias_mask, current_bias, new_bias.value())

            new_bias.assign(init_bias)
            new_lm_head_bias[attr] = new_bias

        return new_lm_head_bias

    def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens):
        """
        Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end.
        Reducing the size will remove vectors from the end.

        Args:
            old_lm_head_decoder (:obj:`tf.Variable`):
                Old lm head decoder to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the linear matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns None

        Return:
            :obj:`tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the
            input ones.
        """
        new_lm_head_decoder = old_lm_head_decoder
        is_input_output_equals = tf.reduce_any(
            self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
        )

        if old_lm_head_decoder is not None and not is_input_output_equals:
            old_embedding_dim = shape_list(old_lm_head_decoder)[1]
            decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)
            new_lm_head_decoder = self.add_weight(
                shape=(new_num_tokens, old_embedding_dim),
                initializer="zeros",
                trainable=True,
                name=old_lm_head_decoder.name.split(":")[0],
            )
            init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())

            new_lm_head_decoder.assign(init_decoder)

        return new_lm_head_decoder

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build resized Embedding weights from provided token Embedding weights. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (:obj:`tf.Variable`):
                Old embeddings to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
                :obj:`tf.Variable` module of the model without doing anything.

        Return:
            :obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if
            :obj:`new_num_tokens` is :obj:`None`.
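
        Example (an illustrative sketch, assuming ``model`` is an already instantiated TF model; in practice this
        method is called internally by :meth:`~transformers.TFPreTrainedModel.resize_token_embeddings`)::

            >>> old_embeddings = model._get_word_embedding_weight(model.get_input_embeddings())
            >>> new_embeddings = model._get_resized_embeddings(old_embeddings, new_num_tokens=30528)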
        """
        old_embedding_dim = shape_list(old_embeddings)[1]
        init_range = getattr(self.config, "initializer_range", 0.02)
        embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
        new_embeddings = self.add_weight(
            name=old_embeddings.name.split(":")[0],
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value())

        new_embeddings.assign(init_embeddings)

        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (:obj:`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list of
                heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads
                0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory, saved_model=False, version=1, push_to_hub=False, **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist.
            saved_model (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also save the model in the TensorFlow SavedModel format.
            version (:obj:`int`, `optional`, defaults to 1):
                The version of the saved model. A saved model needs to be versioned in order to be properly loaded by
                TensorFlow Serving as detailed in the official documentation
                https://www.tensorflow.org/tfx/serving/serving_basic
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                .. warning::

                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
                    instead.

            kwargs:
                Additional keyword arguments passed along to the
                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
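
        Example (for illustration; the directory path is a placeholder)::

            >>> from transformers import TFBertModel

            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> model.save_pretrained('./my_model_directory/')
            >>> # The saved model can then be re-loaded with `from_pretrained`.
            >>> model = TFBertModel.from_pretrained('./my_model_directory/')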
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        if saved_model:
            saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
            self.save(saved_model_dir, include_optimizer=False, signatures=self.serving)
            logger.info(f"Saved model created in {saved_model_dir}")

        # Save configuration file
        self.config.architectures = [self.__class__.__name__[2:]]
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch model in a
                      TensorFlow model using the provided conversion scripts and loading the TensorFlow model
                      afterwards.
                    - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                      model).
                    - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch state_dict save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to ignore weights from the checkpoint that do not have the same size as the
                corresponding weights of the model (if, for instance, you are instantiating a model with 10 labels
                from a checkpoint with 3 labels). Mismatched weights are discarded and newly initialized instead of
                raising an error.
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            use_auth_token (:obj:`str` or `bool`, `optional`):
                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
            revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            mirror (:obj:`str`, `optional`):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        .. note::

            Passing :obj:`use_auth_token=True` is required when you want to use a private model.

        Examples::

            >>> from transformers import BertConfig, TFBertModel
            >>> # Download model and configuration from huggingface.co and cache.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = TFBertModel.from_pretrained('./test/saved_model/')
            >>> # Update configuration during loading.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
            >>> assert model.config.output_attentions == True
            >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            >>> model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)
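            >>> # Loading a specific model version (a branch name, tag, or commit id) from the Hub.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased', revision='main')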

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        mirror = kwargs.pop("mirror", None)
        load_weight_prefix = kwargs.pop("load_weight_prefix", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint in priority if from_pt
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME]} found in directory "
                        f"{pretrained_model_name_or_path} or `from_pt` set to False"
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                    revision=revision,
                    mirror=mirror,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
                    f"  (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)
            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
        else:
            resolved_archive_file = None

        config.name_or_path = pretrained_model_name_or_path

        # composed models, *e.g.* TFRag, require special treatment when it comes to loading
        # pre-trained weights.
        if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
            model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        # we might need to extend the variable scope for composite models
        if load_weight_prefix is not None:
            with tf.compat.v1.variable_scope(load_weight_prefix):
                model(model.dummy_inputs)  # build the network with dummy inputs
        else:
            model(model.dummy_inputs)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), f"Error retrieving file {resolved_archive_file}"
        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
                model,
                resolved_archive_file,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                _prefix=load_weight_prefix,
            )
        except OSError as e:
            try:
                with open(resolved_archive_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please install "
                            "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                            "you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise OSError(
                    "Unable to load weights from h5 file. "
                    "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
                )

        model(model.dummy_inputs)  # Make sure restore ops are run

        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.warning(
                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "mismatched_keys": mismatched_keys,
            }

            return model, loading_info

        return model


# To update the docstring, we need to copy the method, otherwise we change the original docstring.
TFPreTrainedModel.push_to_hub = copy_func(TFPreTrainedModel.push_to_hub)
TFPreTrainedModel.push_to_hub.__doc__ = TFPreTrainedModel.push_to_hub.__doc__.format(
    object="model", object_class="TFAutoModel", object_files="model checkpoint"
)


class TFConv1D(tf.keras.layers.Layer):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (:obj:`int`):
            The number of output features.
        nx (:obj:`int`):
            The number of input features.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
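
    Example (a minimal sketch; :meth:`call` expects a 3D input of shape ``[batch_size, seq_length, nx]``)::

        >>> import tensorflow as tf

        >>> layer = TFConv1D(nf=16, nx=8)  # behaves like a Dense(16) with transposed weights
        >>> output = layer(tf.ones([2, 4, 8]))
        >>> output.shape
        TensorShape([2, 4, 16])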
    """

    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])

        return x


class TFSharedEmbeddings(tf.keras.layers.Layer):
    r"""
    Construct shared token embeddings.

    The weights of the embedding layer are usually shared with the weights of the linear decoder when doing language
    modeling.

    Args:
        vocab_size (:obj:`int`):
            The size of the vocabulary, e.g., the number of unique tokens.
        hidden_size (:obj:`int`):
            The size of the embedding vectors.
        initializer_range (:obj:`float`, `optional`):
            The standard deviation to use when initializing the weights. If no value is provided, it will default to
            :math:`1/\sqrt{hidden\_size}`.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
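
    Example (a minimal sketch of the two calling modes)::

        >>> import tensorflow as tf

        >>> embeddings = TFSharedEmbeddings(vocab_size=10, hidden_size=4)
        >>> hidden_states = embeddings(tf.constant([[1, 2, 3]]))  # embedding mode: shape [1, 3, 4]
        >>> logits = embeddings(hidden_states, mode="linear")  # linear mode: shape [1, 3, 10]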
    """

    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """
        Build shared token embedding layer Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def get_config(self):
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
        """
        Get token embeddings of inputs or decode final hidden state.

        Args:
            inputs (:obj:`tf.Tensor`):
                In embedding mode, should be an int64 tensor with shape :obj:`[batch_size, length]`.

                In linear mode, should be a float tensor with shape :obj:`[batch_size, length, hidden_size]`.
            mode (:obj:`str`, defaults to :obj:`"embedding"`):
               A valid value is either :obj:`"embedding"` or :obj:`"linear"`, the first one indicates that the layer
               should be used as an embedding layer, the second one that the layer should be used as a linear decoder.

        Returns:
            :obj:`tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape
            :obj:`[batch_size, length, embedding_size]`.

            In linear mode, the output is a float32 tensor with shape :obj:`[batch_size, length, vocab_size]`.

        Raises:
            ValueError: if :obj:`mode` is not valid.

        Shared weights logic is adapted from `here
        <https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24>`__.
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError(f"mode {mode} is not valid.")

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [..., hidden_size]

        Returns:
            float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])


class TFSequenceSummary(tf.keras.layers.Layer):
    """
    Compute a single vector summary of a sequence hidden states.

    Args:
        config (:class:`~transformers.PretrainedConfig`):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:

                - :obj:`"last"` -- Take the last token hidden state (like XLNet)
                - :obj:`"first"` -- Take the first token hidden state (like Bert)
                - :obj:`"mean"` -- Take the mean of all tokens hidden states
                - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - :obj:`"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
              :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
            - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
              output, another string or :obj:`None` will add no activation.
            - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
              activation.
            - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and
              activation.

        initializer_range (:obj:`float`, defaults to 0.02): The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
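
    Example (a minimal sketch, using :class:`~transformers.GPT2Config` since its defaults define the relevant
    ``summary_*`` attributes; the summary is projected to ``config.num_labels`` classes)::

        >>> import tensorflow as tf
        >>> from transformers import GPT2Config

        >>> summary = TFSequenceSummary(GPT2Config(), initializer_range=0.02)
        >>> hidden_states = tf.random.uniform([2, 5, 768])
        >>> output = summary(hidden_states)  # shape [2, config.num_labels]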
    """

    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
        super().__init__(**kwargs)

        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, cls_index=None, training=False):
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = tf.expand_dims(cls_index, axis=-1)
            # else:
            # cls_index = cls_index[..., tf.newaxis]
            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output


def shape_list(tensor: tf.Tensor) -> List[int]:
    """
    Deal with dynamic shape in tensorflow cleanly.

    Args:
        tensor (:obj:`tf.Tensor`): The tensor we want the shape of.

    Returns:
        :obj:`List[int]`: The shape of the tensor as a list.
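
    Example (static dimensions are returned as Python ints, dynamic ones as scalar tensors)::

        >>> import tensorflow as tf

        >>> shape_list(tf.ones([2, 3]))
        [2, 3]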
    """
    dynamic = tf.shape(tensor)

    if tensor.shape == tf.TensorShape(None):
        return dynamic

    static = tensor.shape.as_list()

    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
    """
    Creates a :obj:`tf.initializers.TruncatedNormal` with the given range.

    Args:
        initializer_range (:obj:`float`, defaults to 0.02): Standard deviation of the truncated normal initializer.

    Returns:
        :obj:`tf.initializers.TruncatedNormal`: The truncated normal initializer.
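
    Example (illustrative)::

        >>> dense = tf.keras.layers.Dense(64, kernel_initializer=get_initializer(0.02))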
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


class TFWrappedEmbeddings:
    """
    This class wraps a :obj:`TFSharedEmbeddings` layer into a plain Python (non-Keras-layer) class to avoid problems
    with weight restoring. It also makes sure that the layer is called from the correct scope so that the correct
    weights are saved and restored.
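
    Example (an illustrative sketch; in practice the wrapped layer and the absolute scope name come from a composite
    model such as TFRag)::

        >>> import tensorflow as tf

        >>> shared = TFSharedEmbeddings(vocab_size=10, hidden_size=4, name="shared")
        >>> wrapped = TFWrappedEmbeddings(shared)  # with no scope name, calls are simply delegated
        >>> hidden_states = wrapped(tf.constant([[1, 2]]), mode="embedding")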
    """

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name

    def call(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer.call(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer.call(inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer(inputs, mode)