"tests/models/vscode:/vscode.git/clone" did not exist on "ae1cffaf3cd42d0ab1d7529e3b3118725bca0bcf"
fx.py 23.1 KB
Newer Older
1
2
import copy
import functools
import inspect
import random
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from packaging import version
from torch import nn
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
from torch.fx.node import Argument

from .. import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    MODEL_FOR_PRETRAINING_MAPPING,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    MODEL_MAPPING,
    GPT2DoubleHeadsModel,
    PretrainedConfig,
    PreTrainedModel,
    logging,
)
from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available
from ..models.auto import get_values
from .fx_transformations import (
    _cache_attributes,
    _patch_arguments_,
    _restore_attributes_,
    transform_to_dynamic_input_,
    transformation,
)


logger = logging.get_logger(__name__)


def _generate_supported_model_classes(
    model_name: str,
    supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[Type[PreTrainedModel]]:
    model_config_class = CONFIG_MAPPING[model_name]
    task_mapping = {
        "default": MODEL_MAPPING,
        "pretraining": MODEL_FOR_PRETRAINING_MAPPING,
        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    }

    if supported_tasks is None:
        supported_tasks = task_mapping.keys()
    if isinstance(supported_tasks, str):
        supported_tasks = [supported_tasks]

    model_classes = []
    for task in supported_tasks:
        model_class = task_mapping[task].get(model_config_class, None)
        if model_class:
            model_classes.append(model_class)

    return model_classes
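
# Illustrative usage sketch (not executed by this module): assuming the usual BERT heads are
# registered in the auto mappings, one would expect something like:
#
#     _generate_supported_model_classes("bert", supported_tasks=["masked-lm", "question-answering"])
#     # -> [BertForMaskedLM, BertForQuestionAnswering]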


_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "albert",
    "bert",
    "distilbert",
    "mobilebert",
    "electra",
    "megatron-bert",
    "gpt2",
    "gptj",
    "gpt_neo",
    "t5",
]

_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [
    "albert",
    "bert",
    "distilbert",
    "mobilebert",
    "electra",
    "megatron-bert",
]

_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    if isinstance(item, dict):
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
    else:
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))

_SPECIAL_SUPPORTED_MODELS = [
    GPT2DoubleHeadsModel,
]
_SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)

_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES:
    if isinstance(item, dict):
        _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item))
    else:
        _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item))

_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple(
    _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES
)


class HFProxy(Proxy):
    """
    Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing
    the dim, size and __bool__ methods. It can be easily extended by either adding new methods or extending the
    existing ones.
    """

    def __init__(self, node: Node, tracer: Optional[Tracer] = None):
        super().__init__(node, tracer=tracer)
        if hasattr(self, "tracer") and self.tracer is not None:
            self.device = self.tracer.root.device
            self.dtype = next(self.tracer.root.parameters()).dtype

    @property
    def shape(self):
        return self.size()

    def __setitem__(self, key, value):
        pass

    def __contains__(self, key):
        return False
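
# Note (illustrative): HFProxy itself only adds the `shape` property and a couple of no-op dunder
# methods. The concrete `size`/`dim`/`__bool__` replay methods are bound onto each proxy instance
# by HFTracer.proxy() below, using the values recorded by HFTracer.record().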


def _wrap_method_for_model_recording(model, method_name, cache_name):
    """Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
    method = getattr(torch.Tensor, method_name)

    @functools.wraps(method)
    def wrapped(*args, **kwargs):
        if not hasattr(model, cache_name):
            setattr(model, cache_name, [])
        cache = getattr(model, cache_name)
        res = method(*args, **kwargs)
        cache.append(res)
        return res

    return wrapped


def _create_recorded_proxy_method(proxy, method_name, cache_name):
    """
    Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
    during symbolic tracing.
    """

    def method(self, *args, **kwargs):
        cache = getattr(self.tracer.root, cache_name)
        res = cache.pop(0)
        return res

    method.__name__ = method_name
    bound_method = method.__get__(proxy, proxy.__class__)
    setattr(proxy, method_name, bound_method)


def _wrap_method_for_model_tracing(model, method_name, cache_name):
    """
    Helper function that replaces a torch.Tensor method with a version that returns the recorded values during
    symbolic tracing.
    """

    original_method = getattr(torch.Tensor, method_name)

    @functools.wraps(original_method)
    def method(*args, **kwargs):
        cache = getattr(model, cache_name)
        res = cache.pop(0)
        return res

    setattr(torch.Tensor, method_name, method)

    if method_name == "size":
        setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))


def _monkey_patch_tensor_methods_for_model_recording(model, method_names):
    """
    Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference
    before symbolic tracing.
    """
    cache_names = dict()
    original_methods = dict()
    for method_name in method_names:
        cache_name = f"cache_{method_name}"
        cache_names[method_name] = cache_name
        if not hasattr(torch.Tensor, method_name):
            logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
            continue
        original_methods[method_name] = getattr(torch.Tensor, method_name)
        setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name))

        if method_name == "size":
            original_methods["shape"] = torch.Tensor.shape
            setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))

    return cache_names, original_methods


def _reset_tensor_methods(original_methods):
    """Helper function that resets the monkey patched torch.Tensor methods to their original values."""
    for name, method in original_methods.items():
        setattr(torch.Tensor, name, method)
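
# Minimal sketch of the record / replay cycle implemented by the helpers above (illustrative;
# `model` and `dummy_inputs` are assumed to be a supported model and matching keyword inputs):
#
#     cache_names, originals = _monkey_patch_tensor_methods_for_model_recording(model, ["size", "dim", "__bool__"])
#     model(**dummy_inputs)             # every size() / dim() / __bool__ result is appended to model.cache_<method>
#     _reset_tensor_methods(originals)  # torch.Tensor behaves normally again
#
# During symbolic tracing, _wrap_method_for_model_tracing and _create_recorded_proxy_method pop
# those cached values back out in the same order in which they were recorded.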


class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
    regular PyTorch torch.fx.Proxy.
    """

    default_methods_to_record = {"__bool__", "size", "dim"}

    def __init__(self, batch_size=1, sequence_length=(128, 128), num_choices=-1):
        super().__init__()

        if not is_torch_fx_available():
            torch_version = version.parse(importlib_metadata.version("torch"))
            raise ImportError(
                f"Found an incompatible version of torch. Found version {torch_version}, but only version "
                f"{TORCH_FX_REQUIRED_VERSION} is supported."
            )

        encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length
        decoder_sequence_length = (
            sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length
        )
        self.encoder_shape = [batch_size, encoder_sequence_length]
        self.decoder_shape = (
            [batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape)
        )
        self.num_choices = num_choices
        if self.num_choices > 0:
            self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length]
            self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length]

        self.prev_module = None
        self.recorded_methods = None

    def proxy(self, node: Node):
        p = HFProxy(node, self)
        if self.recorded_methods:
            for method_name, cache_name in self.recorded_methods.items():
                _create_recorded_proxy_method(p, method_name, cache_name)
        return p

    def _generate_dummy_input(self, model, input_name):
        """Generates dummy input for model inference recording."""
        model_class = model.__class__
        device = model.device
        inputs_dict = dict()

        if input_name in ["labels", "start_positions", "end_positions"]:
            batch_size = self.encoder_shape[0]
            if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
                inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device)
            elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
                inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class in [
                *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
            ]:
                inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class in [
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
                *get_values(MODEL_FOR_MASKED_LM_MAPPING),
                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
                GPT2DoubleHeadsModel,
            ]:
                inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, device=device)
            elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device)
            else:
                raise NotImplementedError(f"{model_class} not supported yet.")

        elif "mask" in input_name or "ids" in input_name:
            shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
            inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device)
        else:
            shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
            # Copy before extending so that self.encoder_shape / self.decoder_shape are not modified in place.
            shape = shape + [model.config.hidden_size]
            inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device)

        return inputs_dict
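
    # Example of what _generate_dummy_input returns (illustrative): with the default shapes
    # (batch_size=1, sequence_length=128), requesting "input_ids" yields
    # {"input_ids": torch.ones([1, 128], dtype=torch.long)}, while requesting "labels" for a
    # sequence classification model yields {"labels": torch.zeros(1, dtype=torch.long)}.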

    def record(self, model, input_names, method_names=None):
        """
        Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic
        tracing.
        """
        if method_names is None:
            method_names = self.default_methods_to_record

        inputs = {}
        for input_name in input_names:
            inputs.update(self._generate_dummy_input(model, input_name))

        clone = copy.deepcopy(model)
        cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names)
        self.original_methods = original_methods

        clone(**inputs)

        # Useful because sometimes the config is changed at inference time, for instance for
        # classification tasks where config.problem_type can be set.
        model.config = clone.config

        _reset_tensor_methods(original_methods)

        self.recorded_methods = {
            method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name)
        }

        for cache_name in self.recorded_methods.values():
            setattr(model, cache_name, getattr(clone, cache_name))

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if isinstance(attr_val, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
                    return parameter_proxy_cache[n]
        # TODO: condition this on whether dynamic axes were requested.
        if isinstance(attr_val, torch.Tensor):
            for n, p in self.root.named_buffers():
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
                    return parameter_proxy_cache[n]
        return attr_val

    def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph:
        if concrete_args is None:
            concrete_args = {}

        sig = inspect.signature(root.forward)
        input_names = sig.parameters.keys() - concrete_args.keys()

        self.record(root, input_names, method_names=method_names)

        for method_name, cache_name in self.recorded_methods.items():
            _wrap_method_for_model_tracing(root, method_name, cache_name)

        graph = super().trace(root, concrete_args=concrete_args)

        _reset_tensor_methods(self.original_methods)

        # TODO: remove this workaround as soon as it is no longer necessary.
        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        # A PR that solves this was opened: https://github.com/pytorch/pytorch/pull/59569, but it has not been merged yet.
        for node in graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in input_names:
                    node.args = ()
                # It is a concrete arg so it is not used and should be removed.
                else:
                    graph.erase_node(node)

        return graph

    def _insert_module_as_submodule(self, mod):
        """
        Helper method which tries to insert a module that was not declared as a submodule.
        """
        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        while hasattr(self.root, path):
            path = f"{mod_name}_{idx}"
            idx += 1

        self.root.add_module(path, mod)
        return path

    def path_of_module(self, mod: nn.Module) -> str:
        """
        Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
        a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
        string "foo.bar".

        Args:
            mod (nn.Module): The `Module` to retrieve the qualified name for.
        """
        # Prefer the O(1) algorithm
        if hasattr(self, "submodule_paths") and self.submodule_paths:
            path = self.submodule_paths.get(mod)
            if path is None:
                path = self._insert_module_as_submodule(mod)
            if path is None:
                raise NameError(f"Module named {mod._get_name()} is not installed as a submodule")
            self.prev_module = path
            return path

        # O(N^2) fallback in the case that we didn't store the submodule
        # paths.
        else:
            for n, p in self.root.named_modules():
                if mod is p:
                    self.prev_module = n
                    return n
            path = self._insert_module_as_submodule(mod)
            if path is None:
                raise NameError(f"Module {mod._get_name()} is not installed as a submodule")
            self.prev_module = path
            return path

    def create_arg(self, a: Any) -> Argument:
        if isinstance(a, range):
            return super().create_arg(list(a))
        return super().create_arg(a)


@transformation
def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]:
    """
    Prepares a GraphModule produced by symbolic_trace for retracing by:

        - Caching all the attributes specific to the way the model was initially traced
        - Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes

    For instance, the need to retrace a GraphModule can happen when applying quantization.
    """
    attributes = _cache_attributes(gm)
    _patch_arguments_(gm, gm.dynamic2static)

    return gm, attributes


def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]):
    """Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes."""
    _restore_attributes_(gm, attributes)
    # transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired
    # behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values.
    transform_to_dynamic_input_(gm, is_retracing=True)
    _patch_arguments_(gm, gm.static2dynamic)
    return gm


def retrace_graph_with(
    gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None
) -> GraphModule:
    """
    Retraces a GraphModule by either using a tracer or a function using a tracer (for instance
    torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and
    restoring anything necessary after the retrace.
    """
    if tracer is None and func is None:
        raise ValueError("Either a tracer or a function using a tracer must be provided.")
    elif tracer is not None and func is not None:
        raise ValueError("Either provide a tracer or a function using a tracer, but not both.")
    else:
        gm, attributes = prepare_for_retracing(gm)
        tracing_func = tracer.trace if tracer else func
        traced = tracing_func(gm)
        restore_after_retracing_(traced, attributes)
        return traced
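
# Hedged usage sketch (illustrative; assumes `traced` was produced by symbolic_trace below and
# that torch.quantization.quantize_fx is available in the installed torch version):
#
#     import functools
#     from torch.quantization import get_default_qconfig
#     from torch.quantization.quantize_fx import prepare_fx
#
#     qconfig_dict = {"": get_default_qconfig("fbgemm")}
#     prepared = retrace_graph_with(traced, func=functools.partial(prepare_fx, qconfig_dict=qconfig_dict))
#
# Alternatively, pass `tracer=HFTracer(...)` to retrace with this library's tracer instead of a function.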


def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
    if forbidden_values is None:
        forbidden_values = []
    value = random.randint(low, high)
    while value in forbidden_values:
        value = random.randint(low, high)
    return value


def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    batch_size: int = 1,
    sequence_length: Union[int, List[int], Tuple[int, int]] = (128, 128),
    num_choices: int = -1,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PreTrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):
            The names of the inputs of the traced model. If unset, `model.dummy_inputs.keys()` are used instead.
        batch_size (`int`, *optional*, defaults to 1):
            The batch size of the traced model inputs.
        sequence_length (`int` or `List[int]`, *optional*, defaults to `(128, 128)`):
            The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence
            lengths between the encoder and the decoder inputs, this must be `[encoder_sequence_length,
            decoder_sequence_length]`.
        num_choices (`int`, *optional*, defaults to -1):
            The number of possible choices for a multiple choice task.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

    ```python
    from transformers.utils.fx import symbolic_trace

    traced_model = symbolic_trace(
        model,
        input_names=["input_ids", "attention_mask", "token_type_ids"],
        batch_size=1,
        sequence_length=128,
    )
    ```"""
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    sig = inspect.signature(model.forward)
    concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}

    # Preparing HFTracer batch_size and sequence_length values for potential dynamic axes.
    use_dynamic_batch_size = batch_size <= 0
    if isinstance(sequence_length, (list, tuple)):
        use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0
    else:
        use_dynamic_sequence_length = sequence_length <= 0

    if use_dynamic_batch_size or use_dynamic_sequence_length:
        forbidden_values = [
            model.config.num_attention_heads,
            model.config.hidden_size,
            model.config.hidden_size // model.config.num_attention_heads,
        ]
        if use_dynamic_batch_size:
            batch_size = _generate_random_int(forbidden_values=forbidden_values)
        forbidden_values.append(batch_size)
        if use_dynamic_sequence_length:
            encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
            forbidden_values.append(encoder_sequence_length)
            decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
            sequence_length = [encoder_sequence_length, decoder_sequence_length]

    if not isinstance(model, _SUPPORTED_MODELS):
        supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
        raise NotImplementedError(
            f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
        )
    if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance(
        model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES
    ):
        supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES))
        raise NotImplementedError(
            f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}"
        )

    # Tracing.
    tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)

    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

    traced.config = copy.deepcopy(model.config)
    traced.num_choices = num_choices
    traced.dummy_inputs = {}

    for name in input_names:
        traced.dummy_inputs.update(tracer._generate_dummy_input(model, name))

    traced.use_dynamic_batch_size = use_dynamic_batch_size
    traced.use_dynamic_sequence_length = use_dynamic_sequence_length
    traced.static_batch_size = batch_size
    traced.static_sequence_length = sequence_length

    transform_to_dynamic_input_(traced)

    return traced
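

# Hedged example of tracing with dynamic axes (illustrative): passing non-positive sizes asks the
# tracer to treat the batch size and/or sequence length as dynamic, which is only supported for
# the models listed in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES.
#
#     traced = symbolic_trace(
#         model,
#         input_names=["input_ids", "attention_mask", "token_type_ids"],
#         batch_size=-1,
#         sequence_length=-1,
#     )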