generic.py 23.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic utilities
"""

18
import inspect
19
import tempfile
20
from collections import OrderedDict, UserDict
21
from collections.abc import MutableMapping
22
from contextlib import ExitStack, contextmanager
23
from dataclasses import fields, is_dataclass
24
from enum import Enum
25
from functools import partial
26
from typing import Any, ContextManager, Iterable, List, Tuple
27
28

import numpy as np
29
from packaging import version
30

31
32
33
34
35
36
37
38
from .import_utils import (
    get_torch_version,
    is_flax_available,
    is_mlx_available,
    is_tf_available,
    is_torch_available,
    is_torch_fx_proxy,
)
39
40


41
42
43
44
if is_flax_available():
    import jax.numpy as jnp


45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class cached_property(property):
    """
    A `property` variant that evaluates its getter once per instance and then serves the stored result.

    The result is stored on the instance under the attribute name `"__cached_" + <getter name>`. A stored value of
    `None` is treated as "not computed yet", so a getter that returns `None` runs again on every access.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.
    """

    def __get__(self, obj, objtype=None):
        # Access through the class (no instance) hands back the descriptor itself,
        # see docs.python.org/3/howto/descriptor.html#properties
        if obj is None:
            return self
        getter = self.fget
        if getter is None:
            raise AttributeError("unreadable attribute")
        slot = "__cached_" + getter.__name__
        value = getattr(obj, slot, None)
        if value is not None:
            return value
        # First access (or a previous `None` result): compute and remember.
        value = getter(obj)
        setattr(obj, slot, value)
        return value


68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# vendored from distutils.util
def strtobool(val):
    """Map a textual truth value onto the integer 1 (true) or 0 (false).

    The comparison is case-insensitive. Accepted true spellings are 'y', 'yes', 't', 'true', 'on' and '1'; accepted
    false spellings are 'n', 'no', 'f', 'false', 'off' and '0'.

    Raises:
        ValueError: if the lower-cased `val` is none of the spellings above.
    """
    outcomes = dict.fromkeys(("y", "yes", "t", "true", "on", "1"), 1)
    outcomes.update(dict.fromkeys(("n", "no", "f", "false", "off", "0"), 0))

    val = val.lower()
    outcome = outcomes.get(val)
    if outcome is None:
        raise ValueError(f"invalid truth value {val!r}")
    return outcome


83
def infer_framework_from_repr(x):
    """
    Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
    frameworks in a smart order, without the need to import the frameworks).

    Returns:
        `"pt"`, `"tf"`, `"jax"`, `"np"` or `"mlx"` depending on the prefix of `str(type(x))`, or `None` when no
        known prefix matches.
    """
    representation = str(type(x))
    # Checked in order; note that the jax prefix has no trailing dot, so it also covers modules such as `jaxlib`.
    prefix_to_framework = (
        ("<class 'torch.", "pt"),
        ("<class 'tensorflow.", "tf"),
        ("<class 'jax", "jax"),
        ("<class 'numpy.", "np"),
        ("<class 'mlx.", "mlx"),
    )
    for prefix, framework in prefix_to_framework:
        if representation.startswith(prefix):
            return framework
    return None
99
100
101
102
103
104
105
106
107
108
109
110


def _get_frameworks_and_test_func(x):
    """
    Returns an (ordered since we are in Python 3.7+) dictionary framework to test function, which places the framework
    we can guess from the repr first, then Numpy, then the others.
    """
    framework_to_test = {
        "pt": is_torch_tensor,
        "tf": is_tf_tensor,
        "jax": is_jax_tensor,
        "np": is_numpy_array,
        "mlx": is_mlx_array,
    }
    preferred_framework = infer_framework_from_repr(x)
    # We will test this one first, then numpy, then the others.
    frameworks = [] if preferred_framework is None else [preferred_framework]
    if preferred_framework != "np":
        frameworks.append("np")
    frameworks.extend([f for f in framework_to_test if f not in [preferred_framework, "np"]])
    return {f: framework_to_test[f] for f in frameworks}
120
121


122
123
def is_tensor(x):
    """
    Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray`, `np.ndarray` or `mlx.array`
    in the order defined by `infer_framework_from_repr`

    Torch FX proxies and, when flax is available, `jax.core.Tracer` instances are also reported as tensors.
    """
    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(x)
    for test_func in framework_to_test_func.values():
        if test_func(x):
            return True

    # Tracers
    if is_torch_fx_proxy(x):
        return True

    if is_flax_available():
        # Imported lazily so that jax is only touched when flax is reported as available.
        from jax.core import Tracer

        if isinstance(x, Tracer):
            return True

    return False
144
145
146
147
148
149


def _is_numpy(x):
    return isinstance(x, np.ndarray)


150
151
152
153
154
155
156
def is_numpy_array(x):
    """
    Check whether `x` is a numpy array. Public entry point that defers to `_is_numpy`.
    """
    verdict = _is_numpy(x)
    return verdict


157
158
159
160
161
162
def _is_torch(x):
    import torch

    return isinstance(x, torch.Tensor)


163
164
165
166
167
168
169
def is_torch_tensor(x):
    """
    Check whether `x` is a torch tensor. Returns `False` without importing torch when torch is not available.
    """
    if not is_torch_available():
        return False
    return _is_torch(x)


170
171
172
173
174
175
def _is_torch_device(x):
    import torch

    return isinstance(x, torch.device)


176
177
178
179
180
181
182
def is_torch_device(x):
    """
    Check whether `x` is a torch device. Returns `False` without importing torch when torch is not available.
    """
    if not is_torch_available():
        return False
    return _is_torch_device(x)


183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def _is_torch_dtype(x):
    import torch

    if isinstance(x, str):
        if hasattr(torch, x):
            x = getattr(torch, x)
        else:
            return False
    return isinstance(x, torch.dtype)


def is_torch_dtype(x):
    """
    Check whether `x` is a torch dtype. Returns `False` without importing torch when torch is not available.
    """
    if not is_torch_available():
        return False
    return _is_torch_dtype(x)


201
202
203
204
205
206
def _is_tensorflow(x):
    """Return whether `x` is a `tf.Tensor`. Imports tensorflow when called."""
    import tensorflow as tf

    tensor_type = tf.Tensor
    return isinstance(x, tensor_type)


207
208
209
210
211
212
213
def is_tf_tensor(x):
    """
    Check whether `x` is a tensorflow tensor. Returns `False` without importing tensorflow when it is not available.
    """
    if not is_tf_available():
        return False
    return _is_tensorflow(x)


214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def _is_tf_symbolic_tensor(x):
    """
    Return whether `x` is a symbolic tensorflow tensor. Imports tensorflow when called.

    Uses `tf.is_symbolic_tensor` when the installed tensorflow exposes it; otherwise falls back to checking that the
    type of `x` is exactly `tf.Tensor` (subclasses do not match).
    """
    import tensorflow as tf

    if not hasattr(tf, "is_symbolic_tensor"):
        # the `is_symbolic_tensor` predicate is only available starting with TF 2.14
        return type(x) == tf.Tensor
    return tf.is_symbolic_tensor(x)


def is_tf_symbolic_tensor(x):
    """
    Check whether `x` is a tensorflow symbolic tensor (ie. not eager). Returns `False` without importing tensorflow
    when it is not available.
    """
    if not is_tf_available():
        return False
    return _is_tf_symbolic_tensor(x)


231
232
233
234
235
236
def _is_jax(x):
    import jax.numpy as jnp  # noqa: F811

    return isinstance(x, jnp.ndarray)


237
238
239
240
241
242
243
def is_jax_tensor(x):
    """
    Check whether `x` is a Jax tensor. Returns `False` without importing jax when flax is not available.
    """
    if not is_flax_available():
        return False
    return _is_jax(x)


244
def _is_mlx(x):
    """Return whether `x` is an `mlx.core.array`. Imports mlx when called."""
    import mlx.core as mx

    return isinstance(x, mx.array)


def is_mlx_array(x):
    """
    Check whether `x` is a mlx array. Returns `False` without importing mlx when mlx is not available.
    """
    if not is_mlx_available():
        return False
    return _is_mlx(x)


257
258
259
260
def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Jax tensor, MLX array, Numpy array or python list to a python list.

    Dictionaries (`dict` / `UserDict`) are converted value by value into a plain `dict`; lists and tuples are
    converted element by element into a list. Numpy scalars are converted with `tolist()`. Any other object is
    returned unchanged.
    """

    framework_to_py_obj = {
        "pt": lambda obj: obj.detach().cpu().tolist(),
        "tf": lambda obj: obj.numpy().tolist(),
        "jax": lambda obj: np.asarray(obj).tolist(),
        "np": lambda obj: obj.tolist(),
        # Every framework returned by `_get_frameworks_and_test_func` needs a converter here, otherwise the
        # lookup below raises a KeyError for that framework's arrays.
        "mlx": lambda obj: obj.tolist(),
    }

    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]

    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_py_obj[framework](obj)

    # tolist also works on 0d np arrays
    if isinstance(obj, np.number):
        return obj.tolist()
    else:
        return obj


def to_numpy(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Jax tensor, MLX array, Numpy array or python list to a Numpy array.

    Dictionaries (`dict` / `UserDict`) are converted value by value into a plain `dict`; lists and tuples are passed
    to `np.array` as a whole. Any other object is returned unchanged.
    """

    framework_to_numpy = {
        "pt": lambda obj: obj.detach().cpu().numpy(),
        "tf": lambda obj: obj.numpy(),
        "jax": lambda obj: np.asarray(obj),
        "np": lambda obj: obj,
        # Every framework returned by `_get_frameworks_and_test_func` needs a converter here, otherwise the
        # lookup below raises a KeyError for that framework's arrays.
        "mlx": lambda obj: np.array(obj),
    }

    if isinstance(obj, (dict, UserDict)):
        return {k: to_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return np.array(obj)

    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_numpy[framework](obj)

    return obj
311
312
313
314
315
316
317
318
319
320


class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    <Tip warning={true}>

    You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
    before.

    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `ModelOutput` subclasses.
        """
        if is_torch_available():
            # The public `register_pytree_node` (with `serialized_type_name`) is used from torch 2.2 on;
            # older versions go through the private `_register_pytree_node`.
            if version.parse(get_torch_version()) >= version.parse("2.2"):
                _torch_pytree.register_pytree_node(
                    cls,
                    _model_output_flatten,
                    partial(_model_output_unflatten, output_type=cls),
                    serialized_type_name=f"{cls.__module__}.{cls.__name__}",
                )
            else:
                _torch_pytree._register_pytree_node(
                    cls,
                    _model_output_flatten,
                    partial(_model_output_unflatten, output_type=cls),
                )

    def __init__(self, *args, **kwargs):
        """Initialize the underlying `OrderedDict` and check that subclasses are dataclasses.

        Raises:
            TypeError: if the instance belongs to a subclass of `ModelOutput` that is not a dataclass.
        """
        super().__init__(*args, **kwargs)

        # Subclasses of ModelOutput must use the @dataclass decorator
        # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
        # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
        # Just need to check that the current class is not ModelOutput
        is_modeloutput_subclass = self.__class__ != ModelOutput

        if is_modeloutput_subclass and not is_dataclass(self):
            raise TypeError(
                f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
                " This is a subclass of ModelOutput and so must use the @dataclass decorator."
            )

    def __post_init__(self):
        """Check the ModelOutput dataclass.

        Only occurs if @dataclass decorator has been used.

        Besides the consistency checks, this fills the dictionary side of the object: every field whose value is not
        `None` becomes a key. When only the first field is set and it is not a tensor, it may instead be a dict or
        an iterable of `(str, value)` pairs, in which case those pairs are set as attributes/keys.

        Raises:
            ValueError: if the dataclass has no fields, if any field after the first has a default other than `None`,
                or if the first field is an iterable mixing `(str, value)` pairs with other elements.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            # If we do not have an iterator of key/values, set it as attribute
                            self[class_fields[0].name] = first_field
                        else:
                            # If we have a mixed iterator, raise an error
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        """Always raises: keys cannot be deleted from a `ModelOutput`."""
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        """Always raises: `setdefault` is disabled on a `ModelOutput`."""
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        """Always raises: `pop` is disabled on a `ModelOutput`."""
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        """Always raises: `update` is disabled on a `ModelOutput`."""
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        """Look up by key when `k` is a string, otherwise index into `self.to_tuple()` (integer or slice)."""
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        """Set an attribute, mirroring it into the dictionary when `name` is already a key and `value` is not `None`."""
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        """Set a dictionary entry and the attribute of the same name."""
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def __reduce__(self):
        """Pickle support: for dataclass instances, replace the constructor arguments with the field values."""
        if not is_dataclass(self):
            return super().__reduce__()
        # Local is named `factory` so the `callable` builtin is not shadowed.
        factory, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return factory, args, *remaining

    def to_tuple(self) -> Tuple[Any, ...]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())


464
465
466
467
if is_torch_available():
    import torch.utils._pytree as _torch_pytree

    def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
        # Pytree flatten: the values are the leaves, the keys are the context needed to rebuild the output.
        return list(output.values()), list(output.keys())

    def _model_output_unflatten(
        values: Iterable[Any],
        context: "_torch_pytree.Context",
        output_type=None,
    ) -> ModelOutput:
        # Pytree unflatten: pair the keys stored in `context` with `values` and rebuild an `output_type` instance.
        # `output_type` is bound with `functools.partial` at registration time.
        return output_type(**dict(zip(context, values)))

    # Register the base class itself; subclasses are registered by `ModelOutput.__init_subclass__`.
    # The public `register_pytree_node` (with `serialized_type_name`) is used from torch 2.2 on.
    if version.parse(get_torch_version()) >= version.parse("2.2"):
        _torch_pytree.register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
            serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
        )
    else:
        _torch_pytree._register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
        )
490
491


492
class ExplicitEnum(str, Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        # Called by `Enum` when `value` matches no member: raise with the list of accepted values.
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )


class PaddingStrategy(ExplicitEnum):
    """
    String-valued choices accepted by the `padding` argument of [`PreTrainedTokenizerBase.__call__`]. Having them
    as an enum helps tab-completion in an IDE.
    """

    # Member values are the literal strings accepted for `padding`.
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"


class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.
    """

    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
    MLX = "mlx"
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543


class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.

    Args:
        context_managers (`List[ContextManager]`): The context managers to enter, in order.
    """

    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        """Enter every context manager in order. Nothing is bound by `with ... as`."""
        try:
            for context_manager in self.context_managers:
                self.stack.enter_context(context_manager)
        except BaseException as exc:
            # `__exit__` is not called when `__enter__` raises, so the managers that were already entered
            # must be unwound here or they would never be exited.
            self.stack.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, *args, **kwargs):
        """Exit the entered context managers in reverse order.

        The stack's return value is not propagated, so an exception raised in the `with` body is never suppressed.
        """
        self.stack.__exit__(*args, **kwargs)
544
545


546
547
548
549
550
551
552
def can_return_loss(model_class):
    """
    Check if a given model can return loss.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        `True` when the model's forward signature has a `return_loss` parameter whose default is `True`, `False`
        otherwise.
    """
    framework = infer_framework(model_class)
    if framework == "tf":
        signature = inspect.signature(model_class.call)  # TensorFlow models
    elif framework == "pt":
        signature = inspect.signature(model_class.forward)  # PyTorch models
    else:
        signature = inspect.signature(model_class.__call__)  # Flax models

    # Direct lookup instead of scanning every parameter name.
    return_loss = signature.parameters.get("return_loss")
    return return_loss is not None and return_loss.default is True


568
569
570
571
572
573
574
575
def find_labels(model_class):
    """
    Find the labels used by a given model.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        The names of the forward-signature parameters containing `"label"`. For classes whose name contains
        `"QuestionAnswering"`, `start_positions` and `end_positions` are included as well.
    """
    model_name = model_class.__name__
    framework = infer_framework(model_class)
    if framework == "tf":
        signature = inspect.signature(model_class.call)  # TensorFlow models
    elif framework == "pt":
        signature = inspect.signature(model_class.forward)  # PyTorch models
    else:
        signature = inspect.signature(model_class.__call__)  # Flax models

    if "QuestionAnswering" in model_name:
        return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
    else:
        return [p for p in signature.parameters if "label" in p]
588
589
590
591
592
593
594
595
596
597
598
599
600
601


def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
    """Collapse a nested mapping into a one-level dict whose keys join the nested keys with `delimiter`.

    Empty nested mappings are kept as values. A key is only prefixed when the accumulated parent key is truthy;
    otherwise the key is used as is.
    """

    def _walk(mapping, prefix):
        for name, value in mapping.items():
            if prefix:
                path = str(prefix) + delimiter + str(name)
            else:
                path = name
            if value and isinstance(value, MutableMapping):
                # Non-empty nested mapping: descend with the path built so far.
                yield from _walk(value, path)
            else:
                yield path, value

    flat = {}
    for path, value in _walk(d, parent_key):
        flat[path] = value
    return flat
602
603
604
605
606
607
608
609
610


@contextmanager
def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
    """
    Context manager yielding a directory to work in.

    Yields `working_dir` unchanged unless `use_temp_dir` is truthy, in which case a fresh
    `tempfile.TemporaryDirectory` is created, its path is yielded, and it is cleaned up on exit.
    """
    if not use_temp_dir:
        yield working_dir
        return
    with tempfile.TemporaryDirectory() as scratch_dir:
        yield scratch_dir
611
612
613
614
615
616
617
618
619
620
621
622


def transpose(array, axes=None):
    """
    Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: if `array` is not a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.transpose(array, axes=axes)
    elif is_torch_tensor(array):
        return array.T if axes is None else array.permute(*axes)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.transpose(array, perm=axes)
    elif is_jax_tensor(array):
        return jnp.transpose(array, axes=axes)
    else:
        raise ValueError(f"Type not supported for transpose: {type(array)}.")


def reshape(array, newshape):
    """
    Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: if `array` is not a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.reshape(array, newshape)
    elif is_torch_tensor(array):
        # `newshape` is unpacked here, so for torch tensors it must be an iterable of dimensions.
        return array.reshape(*newshape)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.reshape(array, newshape)
    elif is_jax_tensor(array):
        return jnp.reshape(array, newshape)
    else:
        raise ValueError(f"Type not supported for reshape: {type(array)}.")


def squeeze(array, axis=None):
    """
    Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: if `array` is not a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.squeeze(array, axis=axis)
    elif is_torch_tensor(array):
        return array.squeeze() if axis is None else array.squeeze(dim=axis)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.squeeze(array, axis=axis)
    elif is_jax_tensor(array):
        return jnp.squeeze(array, axis=axis)
    else:
        raise ValueError(f"Type not supported for squeeze: {type(array)}.")


def expand_dims(array, axis):
    """
    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: if `array` is not a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.expand_dims(array, axis)
    elif is_torch_tensor(array):
        return array.unsqueeze(dim=axis)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.expand_dims(array, axis=axis)
    elif is_jax_tensor(array):
        return jnp.expand_dims(array, axis=axis)
    else:
        raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
Sylvain Gugger's avatar
Sylvain Gugger committed
687
688
689
690
691
692
693
694
695
696
697


def tensor_size(array):
    """
    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.

    Raises:
        ValueError: if `array` is not a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.size(array)
    elif is_torch_tensor(array):
        return array.numel()
    elif is_tf_tensor(array):
        import tensorflow as tf

        # Returned as whatever `tf.size` produces; it is not converted to a Python int here.
        return tf.size(array)
    elif is_jax_tensor(array):
        return array.size
    else:
        raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
705
706
707
708
709
710
711
712


def add_model_info_to_auto_map(auto_map, repo_id):
    """
    Adds the information of the repo_id to a given auto map.

    Every string value (or string inside a tuple/list value) that does not already contain `"--"` is rewritten to
    `"{repo_id}--{value}"`. `None` entries are left as they are. Tuple values are replaced by lists.

    Returns:
        The same `auto_map` object, modified in place.
    """
    for key, value in auto_map.items():
        # Only existing keys are reassigned, so the dict can be updated while it is being iterated.
        if isinstance(value, (tuple, list)):
            auto_map[key] = [f"{repo_id}--{v}" if (v is not None and "--" not in v) else v for v in value]
        elif value is not None and "--" not in value:
            auto_map[key] = f"{repo_id}--{value}"

    return auto_map
Matt's avatar
Matt committed
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735


def infer_framework(model_class):
    """
    Work out which framework a model class belongs to by walking its MRO and looking only at module and class names.
    `isinstance()` is deliberately not used, because the relevant framework classes may not be imported or available.

    Returns:
        `"tf"`, `"pt"` or `"flax"`, decided by the first class in the MRO that matches.

    Raises:
        TypeError: if no class in the MRO matches any framework.
    """
    tf_modules = ("tensorflow", "keras")
    flax_modules = ("flax", "jax")

    for ancestor in inspect.getmro(model_class):
        origin = ancestor.__module__
        label = ancestor.__name__
        if origin.startswith(tf_modules) or label == "TFPreTrainedModel":
            return "tf"
        if origin.startswith("torch") or label == "PreTrainedModel":
            return "pt"
        if origin.startswith(flax_modules) or label == "FlaxPreTrainedModel":
            return "flax"

    # The whole MRO was scanned without a match.
    raise TypeError(f"Could not infer framework from class {model_class}.")