fx.py 43.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Michael Benayoun's avatar
Michael Benayoun committed
16
import builtins
17
import collections
18
import functools
19
import inspect
20
import math
21
import operator
22
import os
23
import random
Michael Benayoun's avatar
Michael Benayoun committed
24
import warnings
25
from typing import Any, Callable, Dict, List, Optional, Type, Union
26
27

import torch
28
from packaging import version
29
from torch import nn
Michael Benayoun's avatar
Michael Benayoun committed
30
from torch.fx import Graph, GraphModule, Proxy, Tracer
31
from torch.fx.proxy import ParameterProxy
32

33
from .. import PretrainedConfig, PreTrainedModel, logging
34
from ..models.auto import get_values
35
from ..models.auto.modeling_auto import (
36
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
37
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
38
    MODEL_FOR_CTC_MAPPING_NAMES,
39
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
40
41
42
43
44
45
46
47
48
49
50
51
52
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
53
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
54
from ..utils.versions import importlib_metadata
55
56
57


logger = logging.get_logger(__name__)
58
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
59
60


61
def _generate_supported_model_class_names(
    model_name: str,
    supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    """
    Collect the model class names registered for `model_name` across a set of tasks.

    Args:
        model_name: Key looked up in every task mapping. Callers in this file pass plain strings such as
            `"bert"`, hence the `str` annotation.
        supported_tasks: A single task name, a list of task names, or `None` to query every known task.

    Returns:
        The class names found, in the order the tasks were visited. Tasks that have no entry for
        `model_name` are silently skipped.

    Raises:
        KeyError: If a requested task is not one of the keys of the task mapping below.
    """
    task_mapping = {
        "default": MODEL_MAPPING_NAMES,
        "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
        "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
        "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
        "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    }

    if supported_tasks is None:
        supported_tasks = task_mapping.keys()
    if isinstance(supported_tasks, str):
        supported_tasks = [supported_tasks]

    model_class_names = []
    for task in supported_tasks:
        class_name = task_mapping[task].get(model_name, None)
        # A falsy entry (missing or empty) means the model has no class for this task.
        if class_name:
            model_class_names.append(class_name)

    return model_class_names
97
98
99
100


# Model names whose classes are looked up through `_generate_supported_model_class_names`.
# An entry is either a model name (all tasks are queried) or a dict of keyword arguments
# forwarded to that function.
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "albert",
    "bart",
    "bert",
    "blenderbot",
    "blenderbot-small",
    "bloom",
    "clip",
    "convnext",
    "deberta",
    "deberta-v2",
    "distilbert",
    "donut-swin",
    "electra",
    "gpt2",
    "gpt_neo",
    "gptj",
    "hubert",
    "layoutlm",
    "lxmert",
    "m2m_100",
    "marian",
    "mbart",
    "megatron-bert",
    "mobilebert",
    "mt5",
    "nezha",
    "opt",
    "pegasus",
    "plbart",
    "resnet",
    "roberta",
    "speech_to_text",
    "speech_to_text_2",
    "swin",
    "t5",
    "trocr",
    "vit",
    "xglm",
    "wav2vec2",
    #    "xlnet",
]

# Expand every entry above into the concrete class names it maps to.
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    _REGULAR_SUPPORTED_MODELS.extend(
        _generate_supported_model_class_names(**item)
        if isinstance(item, dict)
        else _generate_supported_model_class_names(item)
    )

# Class names listed explicitly rather than derived from the model names above.
_SPECIAL_SUPPORTED_MODELS = [
    "CLIPTextModel",
    "CLIPVisionModel",
    "GPT2DoubleHeadsModel",
    "Speech2Text2Decoder",
    "TrOCRDecoder",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # XLNetForQuestionAnswering,
]

# Deduplicated, alphabetically sorted union of both lists.
_SUPPORTED_MODELS = tuple(sorted({*_REGULAR_SUPPORTED_MODELS, *_SPECIAL_SUPPORTED_MODELS}))
159
160


161
def torch_nn_embedding(self, input):
    """Meta override for `torch.nn.Embedding`: the input shape extended by the last dimension of `self.weight`."""
    output_shape = (*input.shape, self.weight.shape[-1])
    return torch.empty(output_shape, device="meta")


165
166
167
168
169
170
def torch_nn_functional_embedding(
    input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
    """
    Meta override for `torch.nn.functional.embedding`.

    Only `input` and `weight` influence the result: the output has the shape of `input` extended by the last
    dimension of `weight`. The remaining arguments are accepted and ignored.
    """
    output_shape = (*input.shape, weight.shape[-1])
    return torch.empty(output_shape, device="meta")


171
def torch_nn_layernorm(self, input):
    """Meta override for `torch.nn.LayerNorm`: hands the input back unchanged."""
    return input


175
176
177
178
def torch_nn_groupnorm(self, input):
    """Meta override for `torch.nn.GroupNorm`: hands the input back unchanged."""
    return input


179
def torch_nn_linear(self, input):
    """Meta override for `torch.nn.Linear`: the input shape with its last dimension replaced by `self.out_features`."""
    output_shape = (*input.shape[:-1], self.out_features)
    return torch.empty(output_shape, device="meta")


183
def torch_relu(x):
    """Meta override for `torch.relu`: hands the input back unchanged."""
    return x


187
def torch_nn_relu(self, x):
    """Meta override for `torch.nn.ReLU`: hands the input back unchanged."""
    return x


191
def torch_nn_functional_relu(x, inplace=False):
    """
    Meta override for `torch.nn.functional.relu`.

    Args:
        x: The input; returned unchanged.
        inplace: Must be falsy. In-place application is not supported here.

    Raises:
        ValueError: If `inplace` is truthy.
    """
    # The guard used to be inverted (`if not inplace`), rejecting the regular out-of-place call
    # while letting the unsupported in-place call through.
    if inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x


197
def torch_where(condition, x, y):
    """
    Meta override for `torch.where`.

    torch.where returns the broadcasted tensor of condition, x, and y, so the result is obtained by adding
    the three operands together once each has been moved to the meta device.
    """
    meta_condition = condition.to(device="meta")
    meta_x = x.to(device="meta")
    partial_sum = meta_condition + meta_x
    return partial_sum + y.to(device="meta")


203
204
def torch_abs(input, *, out=None):
    """
    Meta override for `torch.abs`: hands the input back unchanged.

    Raises:
        ValueError: If an `out` argument is supplied.
    """
    if out is None:
        return input
    raise ValueError("Don't support in-place abs for MetaTensor analysis")


209
def torch_arange(*args, **kwargs):
    """
    Meta override for `torch.arange`.

    Accepts the positional forms `(end)`, `(start, end)` and `(start, end, step)`. `start`, `end`, `step` and
    `dtype` may also be given as keyword arguments; a `step` keyword takes precedence over a positional step.

    Returns:
        An uninitialized 1-D meta tensor with `ceil((end - start) / step)` elements, which is the length
        `torch.arange` produces.
    """
    start = kwargs.get("start", 0)
    end = kwargs.get("end")
    step = 1
    n = len(args)
    if n == 1:
        end = args[0]
    elif n == 2:
        start, end = args
    elif n > 2:
        start, end, step = args
    step = kwargs.get("step", step)
    dtype = kwargs.get("dtype")
    # The length is computed on the raw values: truncating float bounds to int first (as was done before,
    # with a typo that overwrote `start` with `int(end)`) gives the wrong length, e.g. for arange(1, 2.5).
    # Ceil, not floor: arange(0, 5, 2) has 3 elements.
    length = math.ceil((end - start) / step)
    return torch.empty(length, dtype=dtype, device="meta")


230
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    """
    Meta override for `torch.cat`.

    `axis` is used when `dim` is not given; with neither, dimension 0 is used. A negative dimension is resolved
    against the rank of the first tensor. The result has the shape of the first tensor with the concatenation
    dimension replaced by the sum of that dimension over all tensors. `out` is ignored.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    if dim < 0:
        dim += tensors[0].dim()
    all_shapes = [t.shape for t in tensors]
    first_shape = list(all_shapes[0])
    total = sum(s[dim] for s in all_shapes)
    final_shape = first_shape[:dim] + [total] + first_shape[dim + 1 :]
    return torch.empty(final_shape, device="meta")


244
def torch_stack(tensors, dim=None, axis=None, *, out=None):
    """
    Meta override for `torch.stack`.

    `axis` is used when `dim` is not given; with neither, dimension 0 is used. A negative dimension is resolved
    against the rank of the result (rank of the first tensor plus one). The result has the shape of the first
    tensor with a new dimension of size `len(tensors)` inserted at that position. `out` is ignored.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    first = tensors[0]
    if dim < 0:
        dim += first.dim() + 1
    stacked_shape = list(first.shape)
    stacked_shape.insert(dim, len(tensors))
    return torch.empty(stacked_shape, device="meta")


256
def torch_add(input, other, *, alpha=1, out=None):
    """
    Meta override for `torch.add`: computes the broadcast shape of the two operands.

    Args:
        input: First operand, a tensor or a non-tensor value.
        other: Second operand, a tensor or a non-tensor value.
        alpha: Ignored; it does not influence the output shape.
        out: Ignored.

    Returns:
        An uninitialized meta tensor. When one operand is not a tensor the result mirrors the other operand;
        otherwise its shape is the broadcast of both shapes.
    """
    if not isinstance(input, torch.Tensor):
        return torch.empty_like(other, device="meta")
    if not isinstance(other, torch.Tensor):
        return torch.empty_like(input, device="meta")
    max_length = max(input.dim(), other.dim())
    # Broadcasting aligns the trailing dimensions, so the shorter shape is padded with ones on the LEFT.
    # (Padding on the right turned e.g. (2, 3) + (3,) into (3, 3).)
    input_shape = [1] * (max_length - input.dim()) + list(input.shape)
    other_shape = [1] * (max_length - other.dim()) + list(other.shape)
    # A dimension of size 1 stretches to the size of its counterpart (including size 0).
    shape = [b if a == 1 else a for a, b in zip(input_shape, other_shape)]
    return torch.empty(shape, device="meta")


270
271
def torch_mul(input, other, *, out=None):
    """Meta override for `torch.mul`: delegates to `torch_add`, forwarding `out`."""
    result = torch_add(input, other, out=out)
    return result
Michael Benayoun's avatar
Michael Benayoun committed
272
273


274
275
def torch_tensor_mul(self, other):
    """Meta override for `torch.Tensor.mul`: delegates to `torch_mul`."""
    result = torch_mul(self, other)
    return result
Michael Benayoun's avatar
Michael Benayoun committed
276
277


278
def torch_matmul(input, other, *, out=None):
    """
    Meta override for `torch.matmul`: derives the output shape from the operand shapes.

    Args:
        input: Left operand.
        other: Right operand.
        out: Ignored.

    Returns:
        A meta tensor: a scalar for vector @ vector, otherwise an uninitialized tensor whose leading (batch)
        dimensions are the element-wise maximum of the operands' batch dimensions and whose trailing
        dimensions follow the matrix product (a 1-D operand contributes no dimension).
    """
    d1 = input.dim()
    d2 = other.dim()
    if d1 == 1 and d2 == 1:
        return torch.tensor(0.0, device="meta")
    if d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        # This condition previously read `d1 == 2 and d1 == 1`, which can never be true.
        shape = (input.size(0),)
    else:
        max_length = max(d1, d2)
        # Left-pad the shorter shape with -1 so that `max` always keeps the other operand's batch size.
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = [max(a, b) for a, b in zip(shape1, shape2)]
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        # A 1-D operand is treated as a row/column vector whose extra dimension is dropped afterwards.
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    return torch.empty(*shape, device="meta")


314
315
def torch_bmm(input, mat2, *, out=None):
    """
    Meta override for `torch.bmm`.

    Both operands must unpack into exactly three dimensions. The result takes its first two dimensions from
    `input` and its last dimension from `mat2`.

    Raises:
        ValueError: If an `out` argument is supplied.
    """
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch, rows, _inner = input.shape
    _batch2, _inner2, cols = mat2.shape
    return torch.empty(batch, rows, cols, device="meta")


322
323
324
325
326
327
328
329
330
331
def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
    """
    Meta override for `torch.baddbmm`: the result is that of `torch_bmm(batch1, batch2)`.

    `input`, `beta` and `alpha` do not take part in the computation.

    Raises:
        ValueError: If an `out` argument is supplied.
    """
    if out is None:
        return torch_bmm(batch1, batch2)
    raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")


def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
    """Meta override for `torch.Tensor.baddbmm`: delegates to `torch_baddbmm` with `self` as the input."""
    forwarded = {"beta": beta, "alpha": alpha, "out": out}
    return torch_baddbmm(self, batch1, batch2, **forwarded)


332
333
334
335
336
337
338
def torch_einsum(equation, *operands):
    """
    Meta override for `torch.einsum`.

    The operation is actually executed on uninitialized CPU tensors shaped like the operands, and the result
    is then moved to the meta device.
    """
    # TODO: infer shape without performing the computation, this might be quite hard.
    cpu_operands = [torch.empty_like(op, device="cpu") for op in operands]
    cpu_result = torch.einsum(equation, *cpu_operands)
    return cpu_result.to("meta")


def torch_tensor_repeat(self, *sizes):
    """
    Meta override for `torch.Tensor.repeat`.

    Args:
        sizes: Repetition counts, given either as separate integers or as a single tuple/list.

    Returns:
        An uninitialized meta tensor whose shape is the tensor's shape, left-padded with ones up to
        `len(sizes)` dimensions, multiplied element-wise by `sizes`.
    """
    # `tensor.repeat((2, 3))` is equivalent to `tensor.repeat(2, 3)`.
    if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])
    shape = list(self.shape)
    # Extra repeat sizes add leading dimensions, so the shape is padded on the left before multiplying
    # (indexing the unpadded shape used to raise an IndexError).
    shape = [1] * (len(sizes) - len(shape)) + shape
    for i, x in enumerate(sizes):
        shape[i] *= x
    return torch.empty(shape, device="meta")


def torch_index_select(input, dim, index, *, out=None):
    """Meta override for `torch.index_select`: the input shape with dimension `dim` set to `len(index)`."""
    selected_shape = list(input.shape)
    selected_shape[dim] = len(index)
    return torch.empty(selected_shape, device="meta")


def torch_tensor_index_select(self, dim, index):
    """Meta override for `torch.Tensor.index_select`: delegates to `torch_index_select`."""
    result = torch_index_select(self, dim, index)
    return result
Michael Benayoun's avatar
Michael Benayoun committed
353
354


355
356
357
358
def torch_roll(input, shifts, dims=None):
    """Meta override for `torch.roll`: hands the input back unchanged."""
    return input


359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def torch_flip(input, dims):
    """Meta override for `torch.flip`: hands the input back unchanged."""
    return input


def torch_tensor_flip(self, dims):
    """Meta override for `torch.Tensor.flip`: hands the tensor back unchanged."""
    return self


def torch_nn_conv1d(self, input):
    """
    Meta override for `torch.nn.Conv1d`.

    The channel dimension (second to last) becomes `self.out_channels`. With `"same"` padding the length is
    kept; otherwise it is computed from the module's padding, dilation, kernel size and stride, `"valid"`
    padding counting as zero.
    """
    l_in = input.shape[-1]
    padding = self.padding
    if padding == "valid":
        padding = (0, 0)
    out_shape = list(input.shape)
    if padding != "same":
        kernel_span = self.dilation[0] * (self.kernel_size[0] - 1)
        out_shape[-1] = math.floor((l_in + 2 * padding[0] - kernel_span - 1) / self.stride[0] + 1)
    out_shape[-2] = self.out_channels
    return torch.empty(out_shape, device="meta")


385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
def torch_nn_conv2d(self, input):
    """
    Meta override for `torch.nn.Conv2d`.

    The channel dimension (third to last) becomes `self.out_channels`. With `"same"` padding the spatial size
    is kept; otherwise height and width are computed from the module's padding, dilation, kernel size and
    stride, `"valid"` padding counting as zero.
    """
    h_in, w_in = input.shape[-2:]
    padding = self.padding
    if padding == "valid":
        padding = (0, 0)
    out_shape = list(input.shape)
    if padding != "same":
        spatial = []
        for axis, size_in in enumerate((h_in, w_in)):
            kernel_span = self.dilation[axis] * (self.kernel_size[axis] - 1)
            spatial.append(math.floor((size_in + 2 * padding[axis] - kernel_span - 1) / self.stride[axis] + 1))
        out_shape[-2:] = spatial
    out_shape[-3] = self.out_channels
    return torch.empty(out_shape, device="meta")


406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
def torch_squeeze(input, dim=None):
    """
    Meta override for `torch.squeeze`.

    Without `dim`, every dimension of size 1 is removed. With `dim` (negative values count from the end),
    only that dimension is removed, and only if its size is 1.
    """
    dims = list(input.shape)
    if dim is None:
        dims = [size for size in dims if size != 1]
    else:
        if dim < 0:
            dim += input.dim()
        if dims[dim] == 1:
            del dims[dim]
    return torch.empty(dims, device="meta")


def torch_tensor_squeeze(self, dim=None):
    """Meta override for `torch.Tensor.squeeze`: delegates to `torch_squeeze`."""
    result = torch_squeeze(self, dim)
    return result


427
428
429
430
431
432
433
434
435
436
437
438
def torch_unsqueeze(input, dim):
    """
    Meta override for `torch.unsqueeze`: inserts a dimension of size 1 at position `dim`.

    A negative `dim` is resolved against the rank of the result (input rank plus one).
    """
    if dim < 0:
        dim += input.dim() + 1
    expanded = list(input.shape)
    expanded.insert(dim, 1)
    return torch.empty(expanded, device="meta")


def torch_tensor_unsqueeze(self, dim):
    """Meta override for `torch.Tensor.unsqueeze`: delegates to `torch_unsqueeze`."""
    result = torch_unsqueeze(self, dim)
    return result


439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
def torch_unique_consecutive(input, **kwargs):
    output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
    if isinstance(output, torch.Tensor):
        return output.to("meta")
    else:
        return tuple(map(output, lambda x: x.to("meta")))


def torch_nn_functional_one_hot(tensor, num_classes=-1):
    """
    Meta override for `torch.nn.functional.one_hot`: the input shape extended by a `num_classes` dimension.

    Raises:
        ValueError: If `num_classes` is negative (i.e. it would have to be inferred from the data).
    """
    if num_classes < 0:
        raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
    one_hot_shape = [*tensor.shape, num_classes]
    return torch.empty(one_hot_shape, device="meta")


Michael Benayoun's avatar
Michael Benayoun committed
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
def torch_nn_mseloss(self, input, target):
    """Meta override for `torch.nn.MSELoss`: shaped like `target` when `reduction` is "none", else `(1,)`."""
    loss_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(loss_shape, device="meta")


def torch_nn_crossentropyloss(self, input, target):
    """Meta override for `torch.nn.CrossEntropyLoss`: shaped like `target` when `reduction` is "none", else `(1,)`."""
    loss_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(loss_shape, device="meta")


def torch_nn_bcewithlogitsloss(self, input, target):
    """Meta override for `torch.nn.BCEWithLogitsLoss`: shaped like `target` when `reduction` is "none", else `(1,)`."""
    loss_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(loss_shape, device="meta")


478
def operator_getitem(a, b):
    """
    Meta override for `operator.getitem`.

    For a tensor `a`, the indexing is actually performed on an uninitialized CPU tensor shaped like `a` and
    the result is moved to the meta device. Tensor indices are replaced by CPU tensors of ones; those with a
    float16/float32/float64/int32 dtype are additionally cast to int64. Any other `a` is indexed directly.
    """

    def to_concrete(t):
        if not isinstance(t, torch.Tensor):
            return t
        concrete = torch.ones_like(t, device="cpu")
        if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
            concrete = concrete.to(torch.int64)
        return concrete

    if not isinstance(a, torch.Tensor):
        return operator.getitem(a, b)

    # TODO: infer shape without performing the computation.
    index = tuple(to_concrete(part) for part in b) if isinstance(b, tuple) else to_concrete(b)
    cpu_tensor = torch.empty_like(a, device="cpu")
    return operator.getitem(cpu_tensor, index).to("meta")


Michael Benayoun's avatar
Michael Benayoun committed
497
# Maps a torch module class / function / tensor method (or `operator.getitem`) to the override defined
# above under the matching name.
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    # Embedding, normalization and linear layers
    torch.nn.Embedding: torch_nn_embedding,
    torch.nn.functional.embedding: torch_nn_functional_embedding,
    torch.nn.LayerNorm: torch_nn_layernorm,
    torch.nn.GroupNorm: torch_nn_groupnorm,
    torch.nn.Linear: torch_nn_linear,
    # Activations
    torch.relu: torch_relu,
    torch.nn.functional.relu: torch_nn_functional_relu,
    torch.nn.ReLU: torch_nn_relu,
    # Element-wise operations and tensor construction
    torch.where: torch_where,
    torch.abs: torch_abs,
    torch.arange: torch_arange,
    torch.cat: torch_cat,
    torch.stack: torch_stack,
    torch.add: torch_add,
    torch.mul: torch_mul,
    torch.Tensor.mul: torch_tensor_mul,
    # Matrix products
    torch.matmul: torch_matmul,
    torch.bmm: torch_bmm,
    torch.baddbmm: torch_baddbmm,
    torch.Tensor.baddbmm: torch_tensor_baddbmm,
    torch.einsum: torch_einsum,
    # Re-arranging and indexing
    torch.Tensor.repeat: torch_tensor_repeat,
    torch.roll: torch_roll,
    torch.flip: torch_flip,
    torch.Tensor.flip: torch_tensor_flip,
    torch.index_select: torch_index_select,
    torch.Tensor.index_select: torch_tensor_index_select,
    # Convolutions
    torch.nn.Conv1d: torch_nn_conv1d,
    torch.nn.Conv2d: torch_nn_conv2d,
    # Shape manipulation
    torch.squeeze: torch_squeeze,
    torch.Tensor.squeeze: torch_tensor_squeeze,
    torch.unsqueeze: torch_unsqueeze,
    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
    torch.unique_consecutive: torch_unique_consecutive,
    torch.nn.functional.one_hot: torch_nn_functional_one_hot,
    # Losses
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
    # Python-level indexing
    operator.getitem: operator_getitem,
}


540
541
class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.

    When metadata has been installed (and is not `None`), `dtype`, `len()`, truth-value testing and `in`
    are answered from it instead of being traced.
    """

    def install_metadata(self, metadata):
        """Attach `metadata` to this proxy."""
        self._metadata = metadata

    @property
    def shape(self):
        # Recorded as a call to the `size` method on this proxy.
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def dtype(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata.dtype
        return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            # `__bool__` must return an actual `bool`; returning the metadata object itself makes Python
            # raise a TypeError for anything that is not already a bool.
            return bool(self._metadata)
        return super().__bool__()

    def __getattr__(self, k):
        # Only reached when normal lookup fails. `_metadata` must keep raising AttributeError while it is
        # not installed, otherwise the `hasattr` checks in this class would always succeed.
        if k == "_metadata":
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return HFAttribute(self, k)

    def __setitem__(self, indices, values):
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return key in self._metadata
        return super().__contains__(key)
588
589


Michael Benayoun's avatar
Michael Benayoun committed
590
591
592
593
594
595
class HFAttribute(HFProxy):
    """
    Attribute access `root.attr` on a proxy.

    The graph node for the access is only created the first time `node` is read; calling the object records
    a method call on `root` instead.
    """

    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is not None:
            return self._node
        attribute_proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
        self._node = attribute_proxy.node
        return self._node

    def __call__(self, *args, **kwargs):
        call_args = (self.root,) + args
        return self.tracer.create_proxy("call_method", self.attr, call_args, kwargs)
607
608


Michael Benayoun's avatar
Michael Benayoun committed
609
610
class MetaDeviceAttribute(HFAttribute):
    """Marker subclass returned by `HFProxy.device`; `_proxies_to_metas` maps instances to the string "meta"."""
611

612

Michael Benayoun's avatar
Michael Benayoun committed
613
614
615
616
617
618
619
620
621
def _proxies_to_metas(v):
    """
    Returns the underlying metadata for HFProxies, and behaves like the identity for the others.

    `MetaDeviceAttribute` instances are turned into the string "meta". A proxy without installed metadata
    raises a `RuntimeError`.
    """
    if isinstance(v, MetaDeviceAttribute):
        return "meta"
    if not isinstance(v, torch.fx.Proxy):
        return v
    if isinstance(v, HFProxy) and hasattr(v, "_metadata"):
        return v._metadata
    raise RuntimeError(f"No metadata was found for {v}")
622

623

Michael Benayoun's avatar
Michael Benayoun committed
624
625
626
627
def _gen_constructor_wrapper(target):
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None
628

Michael Benayoun's avatar
Michael Benayoun committed
629
630
631
632
        def check_has_proxy(v):
            if isinstance(v, Proxy):
                nonlocal proxy
                proxy = v
633

Michael Benayoun's avatar
Michael Benayoun committed
634
635
        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
636

Michael Benayoun's avatar
Michael Benayoun committed
637
638
639
640
641
642
        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target
643
644


645
646
647
648
649
650
651
652
653
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
    if forbidden_values is None:
        forbidden_values = []
    value = random.randint(low, high)
    while value in forbidden_values:
        value = random.randint(low, high)
    return value


654
655
class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
    regular PyTorch torch.fx.Proxy.
    """

    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = True
    # When True, `path_of_module` may register a module that has neither parameters nor buffers on the root module
    # instead of failing when the module is not found in the hierarchy.
    allow_insert_stateless_mods: bool = True
    # Names of the `torch` functions that `trace` temporarily replaces with proxy-aware wrappers.
    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
664

665
    def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
        """
        Initializes the underlying `torch.fx.Tracer`, then checks that the installed torch version supports FX.

        Args:
            autowrap_modules: Forwarded unchanged to `Tracer.__init__`.
            autowrap_functions: Forwarded unchanged to `Tracer.__init__`.

        Raises:
            `ImportError`: If `is_torch_fx_available()` reports that torch FX cannot be used.
        """
        super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

        if is_torch_fx_available():
            return

        torch_version = version.parse(importlib_metadata.version("torch"))
        raise ImportError(
            f"Found an incompatible version of torch. Found version {torch_version}, but only version "
            f"{TORCH_FX_REQUIRED_VERSION} is supported."
        )

676
677
678
    def _generate_dummy_input(
        self, model: PreTrainedModel, input_name: str, shape: List[int]
    ) -> Dict[str, torch.Tensor]:
        """
        Generates dummy input for model inference recording.

        The kind of tensor to create is chosen from `input_name`, mostly through substring checks, so the order of
        the branches below matters. All tensors are filled with zeros and created on `model.device`.

        Args:
            model ([`PreTrainedModel`]):
                The model for which the input is generated. Its class name, config and device are read.
            input_name (`str`):
                The name of the forward argument to create a dummy value for.
            shape (`List[int]`):
                Base shape of the input; the first element is used as the batch size.

        Returns:
            `Dict[str, torch.Tensor]`: The generated tensor(s), keyed by input name. For label-like inputs the keys
            depend on the model class and may differ from `input_name` (e.g. both `"start_positions"` and
            `"end_positions"` are produced for question answering models).

        Raises:
            `ValueError`: For sequence classification models whose `config.problem_type` is missing or unknown.
            `NotImplementedError`: If a label-like input is requested for a model class that is not handled.
        """
        # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
        # from pickle, or from the "__class__" attribute in the general case.
        model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
        device = model.device
        inputs_dict = {}

        if input_name in ["labels", "start_positions", "end_positions"]:
            # Label-like inputs: what is generated depends on the task head, identified by the model class name.
            batch_size = shape[0]
            if model_class_name in [
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
                *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
            ]:
                # One integer label per example.
                inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in [
                *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
                *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
                "XLNetForQuestionAnswering",
            ]:
                # Both span boundaries are always generated together, whichever of the names was requested.
                inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
                inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
                # The shape and dtype of the labels depend on the problem type, so it has to be known.
                if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
                    raise ValueError(
                        "Could not retrieve the problem type for the sequence classification task, please set "
                        'model.config.problem_type to one of the following values: "regression", '
                        '"single_label_classification", or "multi_label_classification".'
                    )

                if model.config.problem_type == "regression":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                elif model.config.problem_type == "single_label_classification":
                    labels_shape = (batch_size,)
                    labels_dtype = torch.long
                elif model.config.problem_type == "multi_label_classification":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                else:
                    raise ValueError(
                        'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
                        f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
                    )
                inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)

            elif model_class_name in [
                *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
                "GPT2DoubleHeadsModel",
            ]:
                # Labels use the full input shape.
                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
            else:
                raise NotImplementedError(
                    f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
                )
        elif "pixel_values" in input_name:
            batch_size = shape[0]
            # Look for the image size on the config itself, then on its vision / encoder sub-config, and fall back
            # to a random (height, width) pair when none of them is available.
            image_size = getattr(model.config, "image_size", None)
            if image_size is None:
                if hasattr(model.config, "vision_config"):
                    image_size = model.config.vision_config.image_size
                elif hasattr(model.config, "encoder"):
                    image_size = model.config.encoder.image_size
                else:
                    image_size = (_generate_random_int(), _generate_random_int())

            # If no num_channels is in the config, use some arbitrary value.
            num_channels = getattr(model.config, "num_channels", 3)
            # A scalar image size is used for both the height and the width.
            if not isinstance(image_size, collections.abc.Iterable):
                image_size = (image_size, image_size)
            height, width = image_size
            inputs_dict[input_name] = torch.zeros(
                batch_size, num_channels, height, width, dtype=torch.float32, device=device
            )
        elif "bbox" in input_name:
            # A trailing dimension of size 4 is appended to the base shape.
            inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
        elif "input_features" in input_name:
            inputs_dict[input_name] = torch.zeros(
                *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
            )
        elif "visual_feats" in input_name:
            inputs_dict[input_name] = torch.zeros(
                shape
                + [
                    model.config.visual_feat_dim,
                ],
                dtype=torch.float,
                device=device,
            )
        elif "visual_pos" in input_name:
            inputs_dict[input_name] = torch.zeros(
                shape
                + [
                    model.config.visual_pos_dim,
                ],
                dtype=torch.float,
                device=device,
            )
        elif "inputs" in input_name:
            # NOTE(review): this substring check also matches names such as "inputs_embeds", which therefore get a
            # float tensor of the base shape without any hidden-size dimension -- confirm this is intended.
            inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
        elif "input_values" in input_name:
            # Expects a two-element shape; only the batch size is kept.
            batch_size, _ = shape
            # Generating big sequence length for audio inputs.
            seq_length = _generate_random_int(low=10000, high=20000)
            inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
        elif "mask" in input_name or "ids" in input_name:
            inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
        else:
            # Any other input is given the base shape extended with the hidden size from the config.
            shape_with_hidden_size = shape + [model.config.hidden_size]
            inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)

        return inputs_dict

Michael Benayoun's avatar
Michael Benayoun committed
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        """
        Creates the proxy through `Tracer.create_proxy` and then tries to attach metadata to it.

        The metadata is obtained by running the traced operation on the metadata of its arguments (see
        `_proxies_to_metas`). This step is best-effort: any exception raised while computing it is swallowed, and
        only reported as a warning when `_IS_IN_DEBUG_MODE` is set. In that case the proxy is returned without
        metadata.

        Returns:
            The value returned by `Tracer.create_proxy`, with metadata installed when it could be computed.
        """
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        # Placeholders for which a concrete meta input is known directly receive it as metadata.
        if kind == "placeholder" and target in self.meta_args:
            rv.install_metadata(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if "device" in kwargs:
                kwargs["device"] = "meta"

        try:
            # Replace every proxy found in the arguments by its metadata.
            args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

            if kind == "call_function":
                # A manual override, when registered for the target, is used instead of the target itself.
                meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
                if isinstance(meta_out, torch.Tensor):
                    meta_out = meta_out.to(device="meta")
            elif kind == "call_method":
                # The method is looked up on the class of the first argument's metadata.
                method = getattr(args_metas[0].__class__, target)
                meta_target = _MANUAL_META_OVERRIDES.get(method, method)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == "call_module":
                # `orig_forward` is set by `call_module` right before the module call is recorded.
                if not hasattr(self, "orig_forward"):
                    raise AttributeError(f"{self} does not have an attribute called orig_forward")
                # Attribute accesses must return the real values, not proxies, while the module runs on metadata.
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if mod_type in _MANUAL_META_OVERRIDES:
                        meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False
            elif kind == "get_attr":
                self._disable_module_getattr = True
                try:
                    # Walk the dotted path from the root module down to the attribute.
                    attr_itr = self.root
                    atoms = target.split(".")
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    if isinstance(attr_itr, torch.Tensor):
                        meta_out = attr_itr.to(device="meta")
                    else:
                        meta_out = attr_itr
                finally:
                    self._disable_module_getattr = False
            else:
                # No metadata is computed for the other kinds of nodes.
                return rv

            if not isinstance(rv, Proxy):
                raise ValueError("Don't support composite output yet")
            rv.install_metadata(meta_out)
        except Exception as e:
            # Deliberately best-effort: a failure here must not abort tracing.
            if _IS_IN_DEBUG_MODE:
                warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

        return rv
864

865
    # Replaced by .getattr from PyTorch 1.13
Michael Benayoun's avatar
Michael Benayoun committed
866
867
868
869
    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        """
        Decides what an attribute access on a module returns during tracing.

        Parameters of the root module (and its buffers when `proxy_buffer_attributes` is set) are replaced by
        `get_attr` proxies, which are created once and then reused through `parameter_proxy_cache`. Any other value,
        or any value accessed while `_disable_module_getattr` is set, is returned unchanged.
        """
        if getattr(self, "_disable_module_getattr", False):
            return attr_val

        def lookup_proxy(value, named_tensors, cache):
            # Returns the (cached) proxy of the entry of `named_tensors` that is `value`, or None if there is none.
            for name, tensor in named_tensors:
                if value is not tensor:
                    continue
                if name not in cache:
                    extra_kwargs = {}
                    if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                        if self.param_shapes_constant:
                            extra_kwargs["proxy_factory_fn"] = lambda node: ParameterProxy(self, node, name, value)
                        else:
                            extra_kwargs["proxy_factory_fn"] = None
                    cache[name] = self.create_proxy("get_attr", name, (), {}, **extra_kwargs)  # type: ignore[arg-type]
                return cache[name]
            return None

        if isinstance(attr_val, torch.nn.Parameter):
            parameter_proxy = lookup_proxy(attr_val, self.root.named_parameters(), parameter_proxy_cache)
            if parameter_proxy is not None:
                return parameter_proxy

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            buffer_proxy = lookup_proxy(attr_val, self.root.named_buffers(), parameter_proxy_cache)
            if buffer_proxy is not None:
                return buffer_proxy

        return attr_val
902

903
904
905
906
    # Needed for PyTorch 1.13+
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        """Delegates to `_module_getattr`, passing all three arguments through unchanged."""
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

Michael Benayoun's avatar
Michael Benayoun committed
907
908
909
    def call_module(self, m, forward, args, kwargs):
        """
        Stores `forward` on the tracer as `orig_forward` before delegating to `Tracer.call_module`.

        `create_proxy` reads `orig_forward` to compute the metadata of `call_module` nodes.
        """
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)
910

Michael Benayoun's avatar
Michael Benayoun committed
911
912
    def proxy(self, node):
        """Wraps `node` in an `HFProxy` bound to this tracer."""
        return HFProxy(node, self)
913

914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
        dummy_inputs: Optional[Dict[str, Any]] = None,
        complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
    ) -> Graph:
        """
        Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
        `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
        the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
        `torch.nn.Module` instance to use as the root and add embedded constants to.

        Args:
            root (`torch.nn.Module` or `Callable`):
                Either a `torch.nn.Module` or a function to be traced through. If root is not a
                [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
            concrete_args (`Dict[str, Any]`, *optional*):
                Concrete arguments that should not be treated as Proxies. The dictionary passed in is never modified.
            dummy_inputs (`Dict[str, Any]`, *optional*):
                The dummy inputs needed to handle data-dependent control-flow if `root` is not a
                [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
                [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
            complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
                If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
                `dummy_inputs` and not in `concrete_args` will be added to `concrete_args` with its default value,
                otherwise does nothing.

        Returns:
            `torch.fx.Graph`:
                A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

        Raises:
            `ValueError`: If `concrete_args` has to be completed with a parameter that has no default value.
            `RuntimeError`: If a dummy input has to be generated but `root` is neither a
            [`~transformers.PreTrainedModel`] nor a deserialized traced model.
        """
        sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)

        # Work on a copy so that the dictionary provided by the caller is never mutated.
        concrete_args = {} if concrete_args is None else dict(concrete_args)

        if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
            for param in sig.parameters.values():
                # Values explicitly provided by the caller, as dummy input or as concrete argument, always win over
                # the defaults of the signature.
                if param.name in dummy_inputs or param.name in concrete_args:
                    continue
                if param.default is inspect.Parameter.empty:
                    raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
                concrete_args[param.name] = param.default

        input_names = sig.parameters.keys() - concrete_args.keys()

        # Creating a random input shape to generate dummy inputs.
        batch_size = _generate_random_int()
        sequence_length = _generate_random_int()
        shape = [batch_size, sequence_length]

        if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
            # Multiple choice models take an extra dimension for the choices, right after the batch dimension.
            num_choices = _generate_random_int(low=2, high=5)
            shape.insert(1, num_choices)

        inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
        for input_name in input_names:
            if input_name in inputs:
                continue
            # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
            # be able to use HFTracer._generate_dummy_input.
            if isinstance(root, PreTrainedModel) or type(root).__qualname__.startswith("_deserialize_graph_module"):
                inputs.update(self._generate_dummy_input(root, input_name, shape))
            else:
                raise RuntimeError(
                    f"Could not generate input named {input_name} for because root is not a"
                    " transformers.PreTrainedModel."
                )

        # Only the metadata of the inputs is needed, so tensors are moved to the meta device.
        concrete_metas = {
            input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
            for input_name, input_ in inputs.items()
        }
        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
                concrete_metas[f"**{param.name}"] = {}
        self.meta_args = concrete_metas
        self.patched_torch_methods = {
            target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        # Patch the torch constructors for the duration of the tracing only.
        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        try:
            self.graph = super().trace(root, concrete_args=concrete_args)
        finally:
            # Always restore the original functions, even if tracing failed.
            for name, (_, orig) in self.patched_torch_methods.items():
                setattr(torch, name, orig)

        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in input_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    # Breadth-first collection of the node and all of its transitive users, which are then erased
                    # in reverse order of discovery.
                    to_visit = collections.deque([node])
                    to_delete = collections.OrderedDict()
                    while to_visit:
                        n = to_visit.popleft()
                        to_delete[n] = None
                        to_visit.extend(n.users.keys())

                    for user in reversed(to_delete.keys()):
                        self.graph.erase_node(user)

            # TODO: solves GraphModule creation.
            # Without this, return type annotation "Tuple" is causing code execution failure.
            if node.op == "output":
                node.type = None

        return self.graph
1035

Michael Benayoun's avatar
Michael Benayoun committed
1036
1037
1038
1039
1040
1041
1042
    def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
        """
        Tells whether at least one value stored in `mod.__dict__` is a `Proxy`.

        Such a module was instantiated from traced values: its attributes are input-dependent, so it cannot be
        treated as a leaf module.
        """
        for attribute in mod.__dict__.values():
            if isinstance(attribute, Proxy):
                return True
        return False

1043
    def _insert_module_as_submodule(self, mod: nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as submodule.

        The module is registered on `self.root` under the first free name of the form `{class name in lower
        case}_{index}`. If `mod` is already registered under one of these names, that name is reused.

        Args:
            mod (`nn.Module`): The module to insert.

        Returns:
            `str`: The name under which `mod` is registered on the root module, or an empty string if the module
            cannot be inserted because its instantiation depends on proxies.
        """
        # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
        # It is not possible to insert such modules, those should be traced through.
        if self._stateless_mod_instanciation_depends_on_proxies(mod):
            return ""

        mod_name = mod.__class__.__name__.lower()
        idx = 0
        path = f"{mod_name}_{idx}"
        while hasattr(self.root, path):
            # No need to add multiple instances of the same module.
            if getattr(self.root, path) is mod:
                return path
            # The index is incremented before the next candidate is built, so that each name is probed only once.
            idx += 1
            path = f"{mod_name}_{idx}"

        self.root.add_module(path, mod)
        return path

1067
    def path_of_module(self, mod: nn.Module) -> str:
        """
        Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
        a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
        string "foo.bar".

        When `mod` cannot be found and `allow_insert_stateless_mods` is set, a module that has neither parameters nor
        buffers is inserted in the root module through `_insert_module_as_submodule`, and the resulting path is
        returned. In every other case the original `NameError` is raised again.

        Args:
            mod (`nn.Module`): The `Module` to retrieve the qualified name for.
        """
        try:
            return super().path_of_module(mod)
        except NameError as e:
            # The flag is checked first so that the module is only inspected when insertion is allowed.
            can_be_inserted = (
                self.allow_insert_stateless_mods and not list(mod.parameters()) and not list(mod.buffers())
            )
            if can_be_inserted:
                return self._insert_module_as_submodule(mod)
            raise e
1083

Michael Benayoun's avatar
Michael Benayoun committed
1084
1085
1086
1087
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        Tells whether `m` must be recorded as a single `call_module` node instead of being traced through.

        A module whose instantiation depends on proxies is never a leaf; for any other module the decision is left
        to `Tracer.is_leaf_module`.
        """
        if self._stateless_mod_instanciation_depends_on_proxies(m):
            return False
        return super().is_leaf_module(m, module_qualified_name)
1088

1089

1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
def get_concrete_args(model: nn.Module, input_names: List[str]):
    """
    Builds the concrete arguments to use when tracing `model` with `input_names` as its only inputs.

    Args:
        model (`nn.Module`): The model whose `forward` signature is inspected.
        input_names (`List[str]`): The names of the `forward` parameters that stay actual inputs. Any iterable of
            names is accepted.

    Returns:
        `Dict[str, Any]`: The default value of every `forward` parameter that is not listed in `input_names`, keyed
        by parameter name.

    Raises:
        `ValueError`: If some of the requested names are not parameters of `model.forward`. Only the unknown names
        are reported.
    """
    sig = inspect.signature(model.forward)
    # Materialized once, so that one-shot iterables and non-indexable collections are supported.
    input_names = list(input_names)

    unknown_names = [name for name in input_names if name not in sig.parameters]
    if unknown_names:
        formatted_input_names = ", ".join(unknown_names)
        formatted_allowed_input_names = ", ".join(sig.parameters.keys())
        raise ValueError(
            f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
            f" {formatted_allowed_input_names}"
        )

    requested = set(input_names)
    return {p.name: p.default for p in sig.parameters.values() if p.name not in requested}


def check_if_model_is_supported(model: PreTrainedModel):
    """
    Checks that the class of `model` is listed in `_SUPPORTED_MODELS`.

    Raises:
        `NotImplementedError`: If the class name of `model` is not in `_SUPPORTED_MODELS`. The message lists the
        supported models.
    """
    if model.__class__.__name__ in _SUPPORTED_MODELS:
        return

    supported_model_names = ", ".join(_SUPPORTED_MODELS)
    raise NotImplementedError(
        f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
    )


1112
1113
1114
def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PreTrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly useful for debugging purposes.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

        ```python
        from transformers.utils.fx import symbolic_trace

        traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
        ```
    """
    requested_names = list(model.dummy_inputs.keys() if input_names is None else input_names)
    concrete_args = get_concrete_args(model, requested_names)

    if not disable_check:
        check_if_model_is_supported(model)

    # Tracing.
    graph = HFTracer().trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, graph)

    traced.config = model.config
    # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
    # _generate_dummy_input, where the model class is needed.
    traced.class_for_deserialization = model.__class__
    traced.device = model.device

    return traced