fx.py 46.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Michael Benayoun's avatar
Michael Benayoun committed
16
import builtins
17
import collections
18
import functools
19
import inspect
20
import math
21
import operator
22
import os
23
import random
Michael Benayoun's avatar
Michael Benayoun committed
24
import warnings
25
from typing import Any, Callable, Dict, List, Optional, Type, Union
26
27

import torch
28
from torch import nn
Michael Benayoun's avatar
Michael Benayoun committed
29
from torch.fx import Graph, GraphModule, Proxy, Tracer
30
from torch.fx._compatibility import compatibility
31
from torch.fx.proxy import ParameterProxy
32

33
from .. import PretrainedConfig, PreTrainedModel, logging
34
from ..models.auto import get_values
35
from ..models.auto.modeling_auto import (
36
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
37
    MODEL_FOR_BACKBONE_MAPPING_NAMES,
38
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
39
    MODEL_FOR_CTC_MAPPING_NAMES,
40
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
41
42
43
44
45
46
47
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
donguk.lim's avatar
donguk.lim committed
48
    MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
49
50
51
52
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
53
    MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
54
55
    MODEL_MAPPING_NAMES,
)
56
57
58
59
60
61
62
from ..utils import (
    ENV_VARS_TRUE_VALUES,
    TORCH_FX_REQUIRED_VERSION,
    get_torch_version,
    is_peft_available,
    is_torch_fx_available,
)
63
64


65
66
67
68
if is_peft_available():
    from peft import PeftModel


69
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
# True when the FX_DEBUG_MODE environment variable (upper-cased) is one of ENV_VARS_TRUE_VALUES.
_IS_IN_DEBUG_MODE = os.getenv("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
71
72


73
def _generate_supported_model_class_names(
    model_name: Type[PretrainedConfig],
    supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    """
    Collects the class names registered for `model_name` in the auto-model task mappings.

    Args:
        model_name: Key looked up in each `MODEL_*_MAPPING_NAMES` mapping.
        supported_tasks: A task name, or a list of task names, restricting the lookup. When `None`, every task
            listed in this function is used.

    Returns:
        The class names found, in task order. Tasks with no (or an empty) entry for `model_name` are skipped.

    Raises:
        KeyError: If a requested task is not one of the task names listed in this function.
    """
    task_mapping = {
        "default": MODEL_MAPPING_NAMES,
        "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
        "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
        "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
        "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
        "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
        "backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES,
    }

    # Normalize the task selection to an iterable of task names.
    if supported_tasks is None:
        tasks = list(task_mapping)
    elif isinstance(supported_tasks, str):
        tasks = [supported_tasks]
    else:
        tasks = supported_tasks

    candidates = (task_mapping[task].get(model_name) for task in tasks)
    return [class_name for class_name in candidates if class_name]
111
112
113


# Model names whose classes are collected below. An entry is either a model name (all tasks) or a dict of
# keyword arguments for `_generate_supported_model_class_names`.
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "altclip",
    "albert",
    "bart",
    "bert",
    "blenderbot",
    "blenderbot-small",
    "bloom",
    "clip",
    "convnext",
    "deberta",
    "deberta-v2",
    "dinov2",
    "distilbert",
    "donut-swin",
    "electra",
    "gpt2",
    "gpt_neo",
    "gptj",
    "hubert",
    "layoutlm",
    "lxmert",
    "m2m_100",
    "marian",
    "mbart",
    "megatron-bert",
    "mobilebert",
    "mt5",
    "nezha",
    "opt",
    "pegasus",
    "plbart",
    "resnet",
    "roberta",
    "segformer",
    "speech_to_text",
    "speech_to_text_2",
    "swin",
    "t5",
    "trocr",
    "vit",
    "xglm",
    "wav2vec2",
    #    "xlnet",
]

# Flat list of class names gathered for every entry above.
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    _REGULAR_SUPPORTED_MODELS.extend(
        _generate_supported_model_class_names(**item)
        if isinstance(item, dict)
        else _generate_supported_model_class_names(item)
    )
165
166

# Class names listed explicitly rather than collected from the auto-model mappings.
_SPECIAL_SUPPORTED_MODELS = [
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "AltCLIPTextModel",
    "AltCLIPVisionModel",
    "GitVisionModel",
    "GPT2DoubleHeadsModel",
    "Speech2Text2Decoder",
    "TrOCRDecoder",
    "PeftModelForCausalLM",
    "PeftModelForSeq2SeqLM",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # XLNetForQuestionAnswering,
]
# Sorted, de-duplicated union of the regular and special class names.
_SUPPORTED_MODELS = tuple(sorted({*_REGULAR_SUPPORTED_MODELS, *_SPECIAL_SUPPORTED_MODELS}))
183
184


185
def torch_nn_embedding(self, input):
    """Returns an empty meta tensor shaped `input.shape + (self.weight.shape[-1],)` with the weight's dtype."""
    embedding_dim = self.weight.shape[-1]
    output_shape = (*input.shape, embedding_dim)
    return torch.empty(output_shape, device="meta", dtype=self.weight.dtype)
Michael Benayoun's avatar
Michael Benayoun committed
187
188


189
190
191
def torch_nn_functional_embedding(
    input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
    """
    Returns an empty meta tensor shaped `input.shape + (weight.shape[-1],)` with the dtype of `weight`.

    The remaining parameters are accepted for signature compatibility and are not used.
    """
    output_shape = (*input.shape, weight.shape[-1])
    return torch.empty(output_shape, device="meta", dtype=weight.dtype)
193
194


195
def torch_nn_layernorm(self, input):
    """Stand-in for `torch.nn.LayerNorm` during metadata propagation: hands back `input` itself."""
    return input


199
200
201
202
def torch_nn_groupnorm(self, input):
    """Stand-in for `torch.nn.GroupNorm` during metadata propagation: hands back `input` itself."""
    return input


203
def torch_nn_linear(self, input):
    """Returns an empty meta tensor whose last dimension is `self.out_features`; leading dims follow `input`."""
    leading_dims = tuple(input.shape[:-1])
    return torch.empty(leading_dims + (self.out_features,), device="meta")


207
def torch_relu(x):
    """Stand-in for `torch.relu` during metadata propagation: hands back `x` itself."""
    return x


211
def torch_nn_relu(self, x):
    """Stand-in for `torch.nn.ReLU` during metadata propagation: hands back `x` itself."""
    return x


215
def torch_nn_functional_relu(x, inplace=False):
    """
    Stand-in for `torch.nn.functional.relu` during metadata propagation.

    Args:
        x: The input; it is returned as is.
        inplace: Must be falsy. In-place application is not supported.

    Returns:
        `x` itself.

    Raises:
        ValueError: If `inplace` is truthy.
    """
    # The guard was previously inverted: it rejected the default out-of-place call and let in-place through.
    if inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x


221
def torch_where(condition, x, y):
    """Returns a meta tensor obtained by adding the three arguments after moving each to the meta device."""
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    condition_meta = condition.to(device="meta")
    x_meta = x.to(device="meta")
    y_meta = y.to(device="meta")
    return condition_meta + x_meta + y_meta


227
228
def torch_abs(input, *, out=None):
    """Stand-in for `torch.abs`: hands back `input` itself; raises `ValueError` when `out` is given."""
    if out is None:
        return input
    raise ValueError("Don't support in-place abs for MetaTensor analysis")


233
def torch_arange(*args, **kwargs):
    """
    Returns an empty meta tensor with the length `torch.arange` would produce for the same arguments.

    Args:
        *args: `(end,)`, `(start, end)` or `(start, end, step)`.
        **kwargs: `start`, `end` and `step` given by keyword take precedence over positional values. `dtype` is
            forwarded to `torch.empty`. Other keyword arguments are ignored.

    Returns:
        A 1-D meta tensor of length `ceil((end - start) / step)`.

    Raises:
        ValueError: If more than three positional arguments are given, or no `end` value is provided.
    """
    start, end, step = 0, None, 1
    if len(args) == 1:
        (end,) = args
    elif len(args) == 2:
        start, end = args
    elif len(args) == 3:
        start, end, step = args
    elif args:
        raise ValueError(f"torch.arange takes at most 3 positional arguments, got {len(args)}")

    start = kwargs.get("start", start)
    end = kwargs.get("end", end)
    step = kwargs.get("step", step)
    if end is None:
        raise ValueError("torch.arange requires an `end` value")

    # The element count is the ceiling of the span over the step, e.g. arange(0, 5, 2) has 3 elements.
    # Floats are used as given: truncating them first could turn a fractional step into 0.
    length = int(math.ceil((end - start) / step))
    dtype = kwargs.get("dtype")
    return torch.empty(length, dtype=dtype, device="meta")


254
255
256
257
258
259
260
261
262
def torch_full(*args, **kwargs):
    """
    Calls `torch.full` with the `device` keyword removed and any meta-device fill value replaced.

    Args:
        *args: Positional arguments of `torch.full`; the second one, when present, is the fill value.
        **kwargs: Keyword arguments of `torch.full`. `device` is dropped. `fill_value` may be given here.

    Returns:
        The tensor produced by `torch.full`.
    """

    def _is_meta_tensor(value):
        return isinstance(value, torch.Tensor) and value.device == torch.device("meta")

    args = list(args)
    kwargs_without_device = dict(kwargs)
    kwargs_without_device.pop("device", None)

    # The fill value can be positional or keyword. Indexing args[1] unconditionally raised IndexError for
    # calls such as torch.full(size, fill_value=...).
    if len(args) > 1:
        if _is_meta_tensor(args[1]):
            args[1] = 1  # Any value.
    elif _is_meta_tensor(kwargs_without_device.get("fill_value")):
        kwargs_without_device["fill_value"] = 1  # Any value.

    # NOTE(review): unlike the other overrides in this file, the result is not created on the meta device.
    return torch.full(*args, **kwargs_without_device)


263
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    """
    Returns an empty meta tensor shaped like the first tensor, with dimension `dim` set to the sum of that
    dimension over all `tensors`.

    `dim` wins over `axis`; when neither is given, dimension 0 is used. `out` is ignored.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    first = tensors[0]
    if dim < 0:
        dim += first.dim()
    result_shape = list(first.shape)
    result_shape[dim] = sum(t.shape[dim] for t in tensors)
    return torch.empty(result_shape, device="meta")


277
def torch_stack(tensors, dim=None, axis=None, *, out=None):
    """
    Returns an empty meta tensor shaped like the first tensor with a new dimension of size `len(tensors)`
    inserted at `dim`.

    `dim` wins over `axis`; when neither is given, dimension 0 is used. `out` is ignored.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    first = tensors[0]
    if dim < 0:
        # Negative positions count from the end of the output, which has one more dimension.
        dim += first.dim() + 1
    dims = list(first.shape)
    return torch.empty(dims[:dim] + [len(tensors)] + dims[dim:], device="meta")


289
def torch_add(input, other, *, alpha=1, out=None):
    """
    Returns an empty meta tensor with the broadcast shape of `input` and `other`.

    Args:
        input: A tensor or a non-tensor value (e.g. a Python number).
        other: A tensor or a non-tensor value.
        alpha: Unused; accepted for signature compatibility.
        out: Unused; accepted for signature compatibility.

    Returns:
        A meta tensor. When only one argument is a tensor, the result is shaped like that tensor.
    """
    if not isinstance(input, torch.Tensor):
        return torch.empty_like(other, device="meta")
    if not isinstance(other, torch.Tensor):
        return torch.empty_like(input, device="meta")
    rank = max(input.dim(), other.dim())
    # Broadcasting aligns trailing dimensions, so the shorter shape is padded with ones on the left.
    input_shape = [1] * (rank - input.dim()) + list(input.shape)
    other_shape = [1] * (rank - other.dim()) + list(other.shape)
    # A size-1 dimension takes the size of its counterpart (which may be 0, so `max` would be wrong).
    shape = [b if a == 1 else a for a, b in zip(input_shape, other_shape)]
    return torch.empty(shape, device="meta")


303
304
def torch_mul(input, other, *, out=None):
    """Delegates to `torch_add`, forwarding `input`, `other` and `out`, and returns its result."""
    result = torch_add(input, other, out=out)
    return result
Michael Benayoun's avatar
Michael Benayoun committed
305
306


307
308
def torch_tensor_mul(self, other):
    """Method form of the multiplication override: delegates to `torch_mul(self, other)`."""
    result = torch_mul(self, other)
    return result
Michael Benayoun's avatar
Michael Benayoun committed
309
310


311
def torch_matmul(input, other, *, out=None):
    """
    Returns an empty meta tensor with the shape `torch.matmul(input, other)` would produce.

    Args:
        input: Left operand, at least 1-D.
        other: Right operand, at least 1-D.
        out: Unused; accepted for signature compatibility.

    Returns:
        A 0-dim meta tensor when both operands are 1-D, otherwise a meta tensor whose leading (batch)
        dimensions are broadcast and whose last two dimensions come from the matrix product.
    """
    d1 = input.dim()
    d2 = other.dim()
    if d1 == 1 and d2 == 1:
        # Dot product: the result is a scalar.
        return torch.tensor(0.0, device="meta")

    shape1 = list(input.shape)
    shape2 = list(other.shape)
    # A 1-D operand is treated as a matrix: a row vector on the left, a column vector on the right.
    if d1 == 1:
        shape1 = [1] + shape1
    if d2 == 1:
        shape2 = shape2 + [1]

    # Broadcast the batch dimensions (everything except the last two), aligned on the right.
    batch1, batch2 = shape1[:-2], shape2[:-2]
    batch_rank = max(len(batch1), len(batch2))
    batch1 = [1] * (batch_rank - len(batch1)) + batch1
    batch2 = [1] * (batch_rank - len(batch2)) + batch2
    shape = [b if a == 1 else a for a, b in zip(batch1, batch2)]
    shape += [shape1[-2], shape2[-1]]

    # Drop the dimension that was added for a 1-D operand.
    if d1 == 1:
        shape.pop(-2)
    if d2 == 1:
        shape.pop(-1)
    return torch.empty(*shape, device="meta")


347
348
def torch_bmm(input, mat2, *, out=None):
    """
    Returns an empty meta tensor of shape `(input.shape[0], input.shape[1], mat2.shape[2])`.

    Both operands must be 3-D. Raises `ValueError` when `out` is given.
    """
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch, rows, _ = input.shape
    _, _, cols = mat2.shape
    return torch.empty((batch, rows, cols), device="meta")


355
356
357
358
359
360
361
362
363
364
def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
    """
    Returns the result of `torch_bmm(batch1, batch2)`; `input`, `beta` and `alpha` are not used.

    Raises `ValueError` when `out` is given.
    """
    if out is None:
        return torch_bmm(batch1, batch2)
    raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")


def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
    """Method form of the baddbmm override: forwards every argument to `torch_baddbmm`."""
    result = torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)
    return result


365
366
367
368
369
370
371
def torch_einsum(equation, *operands):
    """Runs `torch.einsum` on empty CPU copies of `operands` and returns the result moved to the meta device."""
    # TODO: infer shape without performing the computation, this might be quite hard.
    concrete_operands = [torch.empty_like(operand, device="cpu") for operand in operands]
    result = torch.einsum(equation, *concrete_operands)
    return result.to("meta")


def torch_tensor_repeat(self, *sizes):
    """
    Returns an empty meta tensor with the shape `self.repeat(*sizes)` would produce.

    Args:
        *sizes: Repeat counts, given either as separate integers or as one tuple/list.

    Returns:
        A meta tensor with one dimension per repeat count.

    Raises:
        ValueError: If fewer repeat counts than dimensions of `self` are given.
    """
    # Accept both x.repeat(2, 3) and x.repeat((2, 3)).
    if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])
    shape = list(self.shape)
    if len(sizes) < len(shape):
        raise ValueError(
            "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"
        )
    # Extra repeat counts add leading dimensions: the tensor's shape is padded with ones on the left.
    shape = [1] * (len(sizes) - len(shape)) + shape
    return torch.empty([dim * count for dim, count in zip(shape, sizes)], device="meta")


378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
def torch_repeat_interleave(*args, dim=None, output_size=None):
    """
    Returns an empty meta tensor with the shape `torch.repeat_interleave` would produce.

    Args:
        *args: Either `(repeats,)` or `(input, repeats[, dim])`.
        dim: Dimension along which values are repeated. When it is `None` and no positional `dim` is given,
            the input is treated as flattened.
        output_size: When given, used as the output length along `dim` instead of `repeats.sum()`.

    Returns:
        A meta tensor of the computed shape.
    """
    if len(args) == 1:
        # repeat_interleave(repeats): the output length is the total number of repetitions.
        repeats = args[0]
        length = output_size if output_size is not None else int(repeats.sum())
        return torch.empty(length, device="meta")

    tensor, repeats = args[0], args[1]
    if dim is None and len(args) > 2:
        dim = args[2]
    shape = list(tensor.shape)
    if dim is None:
        # Flattened input: its length is the number of elements (product of the dims, not their sum).
        shape = [tensor.numel()]
        dim = 0
    if isinstance(repeats, int) or torch.numel(repeats) == 1:
        shape[dim] *= int(repeats)
    else:
        shape[dim] = output_size if output_size is not None else int(repeats.sum())
    return torch.empty(*shape, device="meta")


Michael Benayoun's avatar
Michael Benayoun committed
398
399
400
401
402
403
404
def torch_index_select(input, dim, index, *, out=None):
    """Returns an empty meta tensor shaped like `input` with dimension `dim` resized to `len(index)`."""
    dims = list(input.shape)
    dims[dim] = len(index)
    return torch.empty(*dims, device="meta")


def torch_tensor_index_select(self, dim, index):
    """Method form of the index_select override: delegates to `torch_index_select(self, dim, index)`."""
    result = torch_index_select(self, dim, index)
    return result
Michael Benayoun's avatar
Michael Benayoun committed
406
407


408
409
410
411
412
413
414
415
416
417
def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
    """
    Returns an empty meta tensor with the shape `torch.gather(input, dim, index)` would produce.

    Args:
        input: Source tensor (not used for the shape).
        dim: Gather dimension (not used for the shape).
        index: Index tensor; the result has exactly its shape.
        sparse_grad: Unused; accepted for signature compatibility.
        out: Unused; accepted for signature compatibility.

    Returns:
        A meta tensor shaped like `index`.
    """
    # The result of gather takes the shape of `index` in every dimension, not only along `dim`.
    return torch.empty(tuple(index.shape), device="meta")


def torch_tensor_gather(self, dim, index):
    """Method form of the gather override: delegates to `torch_gather(self, dim, index)`."""
    result = torch_gather(self, dim, index)
    return result


418
419
420
421
def torch_roll(input, shifts, dims=None):
    """Stand-in for `torch.roll` during metadata propagation: hands back `input` itself."""
    return input


422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
def torch_flip(input, dims):
    """Stand-in for `torch.flip` during metadata propagation: hands back `input` itself."""
    return input


def torch_tensor_flip(self, dims):
    """Stand-in for `torch.Tensor.flip` during metadata propagation: hands back `self`."""
    return self


def torch_nn_conv1d(self, input):
    """
    Returns an empty meta tensor with the output shape computed from the module's convolution settings.

    The channel dimension (second to last) becomes `self.out_channels`. The last dimension is kept for
    `padding == "same"` and otherwise computed from `padding`, `dilation`, `kernel_size` and `stride`.
    """
    length_in = input.shape[-1]
    output_shape = list(input.shape)
    padding = (0, 0) if self.padding == "valid" else self.padding
    if padding != "same":
        kernel_span = self.dilation[0] * (self.kernel_size[0] - 1)
        output_shape[-1] = math.floor((length_in + 2 * padding[0] - kernel_span - 1) / self.stride[0] + 1)
    output_shape[-2] = self.out_channels
    return torch.empty(output_shape, device="meta")


448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def torch_nn_conv2d(self, input):
    """
    Returns an empty meta tensor with the output shape computed from the module's convolution settings.

    The channel dimension (third to last) becomes `self.out_channels`. The last two dimensions are kept for
    `padding == "same"` and otherwise computed per axis from `padding`, `dilation`, `kernel_size` and `stride`.
    """
    height_in, width_in = input.shape[-2:]
    output_shape = list(input.shape)
    padding = (0, 0) if self.padding == "valid" else self.padding
    if padding != "same":
        output_shape[-2:] = [
            math.floor(
                (size + 2 * padding[axis] - self.dilation[axis] * (self.kernel_size[axis] - 1) - 1)
                / self.stride[axis]
                + 1
            )
            for axis, size in enumerate((height_in, width_in))
        ]
    output_shape[-3] = self.out_channels
    return torch.empty(output_shape, device="meta")


469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
def torch_squeeze(input, dim=None):
    """
    Returns an empty meta tensor shaped like `input` with size-1 dimensions removed.

    With `dim` given, only that dimension is removed, and only if its size is 1; otherwise every size-1
    dimension is dropped.
    """
    dims = list(input.shape)
    if dim is None:
        dims = [size for size in dims if size != 1]
    else:
        if dim < 0:
            dim += input.dim()
        if dims[dim] == 1:
            del dims[dim]
    return torch.empty(dims, device="meta")


def torch_tensor_squeeze(self, dim=None):
    """Method form of the squeeze override: delegates to `torch_squeeze(self, dim)`."""
    result = torch_squeeze(self, dim)
    return result


490
491
492
493
494
495
496
497
498
499
500
501
def torch_unsqueeze(input, dim):
    """Returns an empty meta tensor shaped like `input` with a size-1 dimension inserted at `dim`."""
    if dim < 0:
        # Negative positions count from the end of the output, which has one more dimension.
        dim += input.dim() + 1
    dims = list(input.shape)
    return torch.empty(dims[:dim] + [1] + dims[dim:], device="meta")


def torch_tensor_unsqueeze(self, dim):
    """Method form of the unsqueeze override: delegates to `torch_unsqueeze(self, dim)`."""
    result = torch_unsqueeze(self, dim)
    return result


502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
def torch_unique_consecutive(input, **kwargs):
    """
    Runs `torch.unique_consecutive` on a zero-filled CPU copy of `input` and moves the result to meta.

    Args:
        input: The tensor whose shape and dtype are used for the CPU copy.
        **kwargs: Forwarded to `torch.unique_consecutive`.

    Returns:
        A meta tensor, or a tuple of meta tensors when `torch.unique_consecutive` returns several outputs.
    """
    output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
    if isinstance(output, torch.Tensor):
        return output.to("meta")
    # `map` takes the function first; the former call had the arguments swapped and raised TypeError.
    return tuple(item.to("meta") for item in output)


def torch_nn_functional_one_hot(tensor, num_classes=-1):
    """
    Returns an empty meta tensor shaped `tensor.shape + (num_classes,)`.

    Raises `ValueError` when `num_classes` is negative, since it cannot be inferred here.
    """
    if num_classes < 0:
        raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
    return torch.empty([*tensor.shape, num_classes], device="meta")


Michael Benayoun's avatar
Michael Benayoun committed
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
def torch_nn_mseloss(self, input, target):
    """Returns an empty meta tensor shaped like `target` when `self.reduction == "none"`, else of shape (1,)."""
    shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(shape, device="meta")


def torch_nn_crossentropyloss(self, input, target):
    """Returns an empty meta tensor shaped like `target` when `self.reduction == "none"`, else of shape (1,)."""
    shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(shape, device="meta")


def torch_nn_bcewithlogitsloss(self, input, target):
    """Returns an empty meta tensor shaped like `target` when `self.reduction == "none"`, else of shape (1,)."""
    shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(shape, device="meta")


541
def operator_getitem(a, b):
    """
    Indexes `a` with `b`. For a tensor `a`, the indexing runs on an empty CPU copy and the result is moved to
    the meta device; for anything else this is plain `operator.getitem`.

    Tensor indices are replaced by CPU tensors filled with ones; float16/float32/float64/int32 index tensors
    are additionally cast to int64.
    """
    if not isinstance(a, torch.Tensor):
        return operator.getitem(a, b)

    def to_concrete(t):
        if not isinstance(t, torch.Tensor):
            return t
        concrete = torch.ones_like(t, device="cpu")
        if concrete.dtype in (torch.float16, torch.float32, torch.float64, torch.int32):
            return concrete.to(torch.int64)
        return concrete

    # TODO: infer shape without performing the computation.
    index = tuple(to_concrete(item) for item in b) if isinstance(b, tuple) else to_concrete(b)
    return operator.getitem(torch.empty_like(a, device="cpu"), index).to("meta")


Michael Benayoun's avatar
Michael Benayoun committed
560
# Maps a torch callable or module class (plus `operator.getitem`) to the function in this file that stands in
# for it when computing metadata.
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    # Modules and functions that keep or trivially derive the input shape.
    torch.nn.Embedding: torch_nn_embedding,
    torch.nn.functional.embedding: torch_nn_functional_embedding,
    torch.nn.LayerNorm: torch_nn_layernorm,
    torch.nn.GroupNorm: torch_nn_groupnorm,
    torch.nn.Linear: torch_nn_linear,
    torch.relu: torch_relu,
    torch.nn.functional.relu: torch_nn_functional_relu,
    torch.nn.ReLU: torch_nn_relu,
    torch.where: torch_where,
    torch.abs: torch_abs,
    # Tensor constructors.
    torch.arange: torch_arange,
    torch.full: torch_full,
    # Joining and arithmetic.
    torch.cat: torch_cat,
    torch.stack: torch_stack,
    torch.add: torch_add,
    torch.mul: torch_mul,
    torch.Tensor.mul: torch_tensor_mul,
    torch.matmul: torch_matmul,
    torch.bmm: torch_bmm,
    torch.baddbmm: torch_baddbmm,
    torch.Tensor.baddbmm: torch_tensor_baddbmm,
    torch.einsum: torch_einsum,
    # Repeating, reordering and indexing.
    torch.Tensor.repeat: torch_tensor_repeat,
    torch.repeat_interleave: torch_repeat_interleave,
    torch.roll: torch_roll,
    torch.flip: torch_flip,
    torch.Tensor.flip: torch_tensor_flip,
    torch.index_select: torch_index_select,
    torch.Tensor.index_select: torch_tensor_index_select,
    torch.gather: torch_gather,
    torch.Tensor.gather: torch_tensor_gather,
    # Convolutions.
    torch.nn.Conv1d: torch_nn_conv1d,
    torch.nn.Conv2d: torch_nn_conv2d,
    # Shape manipulation.
    torch.squeeze: torch_squeeze,
    torch.Tensor.squeeze: torch_tensor_squeeze,
    torch.unsqueeze: torch_unsqueeze,
    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
    torch.unique_consecutive: torch_unique_consecutive,
    torch.nn.functional.one_hot: torch_nn_functional_one_hot,
    # Losses.
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
    # Python-level indexing.
    operator.getitem: operator_getitem,
}


607
608
class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.

    When metadata has been installed with `install_metadata`, `len()`, truth-testing and `in` are answered
    from it; otherwise they fall back to the `torch.fx.Proxy` behavior.
    """

    def install_metadata(self, metadata):
        """Stores `metadata` on the proxy as `_metadata`."""
        self._metadata = metadata

    @property
    def shape(self):
        """A proxy for a `size` method call on this proxy."""
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            # `__bool__` must return a real bool. Returning the raw metadata raised TypeError for any
            # non-bool value, e.g. an int.
            return bool(self._metadata)
        return super().__bool__()

    def __getattr__(self, k):
        # `_metadata` must resolve through normal lookup so that `hasattr(self, "_metadata")` is False
        # until metadata is installed.
        if k == "_metadata":
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return HFAttribute(self, k)

    def __setitem__(self, indices, values):
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return key in self._metadata
        return super().__contains__(key)
649
650


Michael Benayoun's avatar
Michael Benayoun committed
651
652
653
654
655
656
class HFAttribute(HFProxy):
    """
    Proxy for the attribute `attr` of the proxy `root`.

    The graph node is only created when `node` is first accessed; calling the attribute records a method call
    on `root` instead.
    """

    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

        # Derive this attribute's metadata from the root's metadata when the root has some.
        if hasattr(root, "_metadata"):
            self.install_metadata(getattr(root._metadata, attr))

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is not None:
            return self._node
        proxy = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {})
        self._node = proxy.node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy("call_method", self.attr, (self.root, *args), kwargs)
671
672


Michael Benayoun's avatar
Michael Benayoun committed
673
674
class MetaDeviceAttribute(HFAttribute):
    """Marker subclass of `HFAttribute` returned by `HFProxy.device`; it adds no behavior of its own."""
675

676

Michael Benayoun's avatar
Michael Benayoun committed
677
678
679
680
681
682
683
684
685
def _proxies_to_metas(v):
    """
    Returns the underlying metadata for HFProxies, and behaves like the identity for the others.

    A `MetaDeviceAttribute` maps to the string "meta". A proxy without installed metadata raises
    `RuntimeError`.
    """
    if isinstance(v, MetaDeviceAttribute):
        return "meta"
    if not isinstance(v, torch.fx.Proxy):
        return v
    has_metadata = isinstance(v, HFProxy) and hasattr(v, "_metadata")
    if not has_metadata:
        raise RuntimeError(f"No metadata was found for {v}")
    return v._metadata
686

687

Michael Benayoun's avatar
Michael Benayoun committed
688
689
690
691
def _gen_constructor_wrapper(target):
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None
692

Michael Benayoun's avatar
Michael Benayoun committed
693
694
695
696
        def check_has_proxy(v):
            if isinstance(v, Proxy):
                nonlocal proxy
                proxy = v
697

Michael Benayoun's avatar
Michael Benayoun committed
698
699
        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
700

Michael Benayoun's avatar
Michael Benayoun committed
701
702
703
704
705
706
        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target
707
708


709
710
711
712
713
714
715
716
717
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
    if forbidden_values is None:
        forbidden_values = []
    value = random.randint(low, high)
    while value in forbidden_values:
        value = random.randint(low, high)
    return value


718
719
class HFTracer(Tracer):
    """
720
721
    Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
    regular PyTorch torch.fx.Proxy.
722
723
    """

724
725
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = True
Michael Benayoun's avatar
Michael Benayoun committed
726
    allow_insert_stateless_mods: bool = True
727
728
729
730
731
732
733
734
735
736
737
738
    # Names of the `torch` functions that `trace` temporarily replaces with the proxy-aware wrappers built by
    # `_gen_constructor_wrapper`; the original functions are put back once tracing finishes.
    _TORCH_METHODS_TO_PATCH = [
        "arange",
        "zeros",
        "ones",
        "full",
        "full_like",
        "eye",
        "empty",
        "tensor",
        "clamp",
        "finfo",
    ]
739
    supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
740

741
742
    def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
        """
        Initializes the parent tracer with the given autowrap settings.

        Raises:
            ImportError: If `is_torch_fx_available()` returns `False`.
        """
        super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

        if is_torch_fx_available():
            return

        raise ImportError(
            f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
            f"{TORCH_FX_REQUIRED_VERSION} is supported."
        )

750
751
752
    def _generate_dummy_input(
        self, model: PreTrainedModel, input_name: str, shape: List[int]
    ) -> Dict[str, torch.Tensor]:
        """
        Generates dummy input for model inference recording.

        Args:
            model ([`PreTrainedModel`]):
                The model the input is generated for. Its class name, `device` and `config` drive the result.
            input_name (`str`):
                The name of the input to generate. The branch taken is selected by name matching, so the order of
                the checks below matters.
            shape (`List[int]`):
                The base shape of the input; its first element is used as the batch size.

        Returns:
            `Dict[str, torch.Tensor]`: The generated zero-filled tensor(s), keyed by input name. For label-like
            inputs the keys depend on the model class (`"labels"`, or `"start_positions"` and `"end_positions"`).

        Raises:
            ValueError: For sequence classification models whose `config.problem_type` is missing or unknown.
            NotImplementedError: If a label-like input is requested for a model class that is not handled.
        """
        # The model class comes from the "class_for_deserialization" attribute when the model was restored from
        # pickle, and from the "__class__" attribute in the general case.
        model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
        device = model.device

        if input_name in ["labels", "start_positions", "end_positions"]:
            batch_size = shape[0]

            # One label per example.
            if model_class_name in [
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
                *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
                *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
            ]:
                return {"labels": torch.zeros(batch_size, dtype=torch.long, device=device)}

            # Question answering models get both span boundaries at once.
            if model_class_name in [
                *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
                *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
                "XLNetForQuestionAnswering",
            ]:
                return {
                    "start_positions": torch.zeros(batch_size, dtype=torch.long, device=device),
                    "end_positions": torch.zeros(batch_size, dtype=torch.long, device=device),
                }

            if model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
                problem_type = getattr(model.config, "problem_type", None)
                if problem_type is None:
                    raise ValueError(
                        "Could not retrieve the problem type for the sequence classification task, please set "
                        'model.config.problem_type to one of the following values: "regression", '
                        '"single_label_classification", or "multi_label_classification".'
                    )

                if problem_type == "single_label_classification":
                    labels_shape, labels_dtype = (batch_size,), torch.long
                elif problem_type in ("regression", "multi_label_classification"):
                    labels_shape, labels_dtype = (batch_size, model.config.num_labels), torch.float32
                else:
                    raise ValueError(
                        'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
                        f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
                    )
                return {"labels": torch.zeros(*labels_shape, dtype=labels_dtype, device=device)}

            # Labels with the same shape as the input.
            if model_class_name in [
                *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
                "GPT2DoubleHeadsModel",
                "PeftModelForCausalLM",
                "PeftModelForSeq2SeqLM",
            ]:
                return {"labels": torch.zeros(shape, dtype=torch.long, device=device)}

            if model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
                return {"labels": torch.zeros(shape, dtype=torch.float32, device=device)}

            raise NotImplementedError(
                f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
            )

        if "pixel_values" in input_name:
            batch_size = shape[0]
            config = model.config

            image_size = getattr(config, "image_size", None)
            if image_size is None:
                if hasattr(config, "vision_config"):
                    image_size = config.vision_config.image_size
                elif hasattr(config, "encoder"):
                    image_size = config.encoder.image_size
                else:
                    image_size = (_generate_random_int(), _generate_random_int())

            # If no num_channels is in the config, use some arbitrary value.
            num_channels = getattr(config, "num_channels", 3)

            # A scalar image size describes a square image.
            if not isinstance(image_size, collections.abc.Iterable):
                image_size = (image_size, image_size)
            height, width = image_size

            pixel_values = torch.zeros(batch_size, num_channels, height, width, dtype=torch.float32, device=device)
            return {input_name: pixel_values}

        if "bbox" in input_name:
            return {input_name: torch.zeros(*shape, 4, dtype=torch.float, device=device)}

        if "input_features" in input_name:
            features = torch.zeros(*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device)
            return {input_name: features}

        if "visual_feats" in input_name:
            feats_shape = shape + [model.config.visual_feat_dim]
            return {input_name: torch.zeros(feats_shape, dtype=torch.float, device=device)}

        if "visual_pos" in input_name:
            pos_shape = shape + [model.config.visual_pos_dim]
            return {input_name: torch.zeros(pos_shape, dtype=torch.float, device=device)}

        if "inputs" in input_name:
            return {input_name: torch.zeros(*shape, dtype=torch.float, device=device)}

        if "input_values" in input_name:
            batch_size, _ = shape
            # Generating big sequence length for audio inputs.
            seq_length = _generate_random_int(low=10000, high=20000)
            return {input_name: torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)}

        if "mask" in input_name or "ids" in input_name:
            return {input_name: torch.zeros(shape, dtype=torch.long, device=device)}

        # Fallback: a float tensor with an extra trailing dimension of size `config.hidden_size`.
        shape_with_hidden_size = shape + [model.config.hidden_size]
        return {input_name: torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)}

Michael Benayoun's avatar
Michael Benayoun committed
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        """
        Creates the proxy through the parent tracer, then tries to attach metadata to it.

        The metadata is obtained by running the recorded operation on the metadata of its inputs (see
        `_proxies_to_metas`):

        - `placeholder`: the entry of `self.meta_args` for `target` is installed, when there is one.
        - `call_function`: `target` (or its `_MANUAL_META_OVERRIDES` replacement) is called; a tensor result is
          moved to the meta device.
        - `call_method`: the method is looked up on the class of the first argument's metadata.
        - `call_module`: the `_MANUAL_META_OVERRIDES` entry for the module type is used if present, otherwise
          `self.orig_forward`.
        - `get_attr`: the dotted `target` is resolved starting from `self.root`.
        - any other kind: the proxy is returned without metadata.

        Metadata computation is best-effort: any exception raised while computing it is swallowed (a warning is
        emitted only when `_IS_IN_DEBUG_MODE` is set) and the proxy is returned without metadata.

        Returns:
            The proxy created by the parent tracer.
        """
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        if kind == "placeholder" and target in self.meta_args:
            rv.install_metadata(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if "device" in kwargs:
                # The node has already been recorded above with the original device. The override only matters
                # for the metadata computation below, so it is applied to a copy to leave the caller's dictionary
                # untouched.
                kwargs = {**kwargs, "device": "meta"}

        try:
            args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

            if kind == "call_function":
                meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
                if isinstance(meta_out, torch.Tensor):
                    meta_out = meta_out.to(device="meta")
            elif kind == "call_method":
                method = getattr(args_metas[0].__class__, target)
                meta_target = _MANUAL_META_OVERRIDES.get(method, method)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == "call_module":
                if not hasattr(self, "orig_forward"):
                    raise AttributeError(f"{self} does not have an attribute called orig_forward")
                # Attribute accesses performed while computing the metadata must return the real values.
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if mod_type in _MANUAL_META_OVERRIDES:
                        meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False
            elif kind == "get_attr":
                self._disable_module_getattr = True
                try:
                    attr_itr = self.root
                    atoms = target.split(".")
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    if isinstance(attr_itr, torch.Tensor):
                        meta_out = attr_itr.to(device="meta")
                    else:
                        meta_out = attr_itr
                finally:
                    self._disable_module_getattr = False
            else:
                return rv

            if not isinstance(rv, Proxy):
                raise ValueError("Don't support composite output yet")
            rv.install_metadata(meta_out)
        except Exception as e:
            if _IS_IN_DEBUG_MODE:
                warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

        return rv
943

944
    # Replaced by .getattr from PyTorch 1.13
Michael Benayoun's avatar
Michael Benayoun committed
945
946
947
948
    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        """
        Returns a `get_attr` proxy for `attr_val` when it is a parameter (or, if `proxy_buffer_attributes` is set, a
        buffer) of `self.root`, and `attr_val` itself otherwise.

        Proxies are created once per qualified name and memoized in `parameter_proxy_cache`. While
        `_disable_module_getattr` is set, `attr_val` is always returned as is.
        """
        if getattr(self, "_disable_module_getattr", False):
            return attr_val

        def lookup_proxy(named_tensors):
            # Finds the entry that *is* `attr_val` and returns its (possibly freshly created) proxy.
            for qualified_name, tensor in named_tensors:
                if attr_val is not tensor:
                    continue

                if qualified_name not in parameter_proxy_cache:
                    extra_kwargs = {}
                    # Only forward `proxy_factory_fn` when `create_proxy` accepts it.
                    if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                        if self.param_shapes_constant:
                            extra_kwargs["proxy_factory_fn"] = lambda node: ParameterProxy(
                                self, node, qualified_name, attr_val
                            )
                        else:
                            extra_kwargs["proxy_factory_fn"] = None
                    parameter_proxy_cache[qualified_name] = self.create_proxy(
                        "get_attr", qualified_name, (), {}, **extra_kwargs
                    )  # type: ignore[arg-type]

                return parameter_proxy_cache[qualified_name]
            return None

        if isinstance(attr_val, torch.nn.Parameter):
            parameter_proxy = lookup_proxy(self.root.named_parameters())
            if parameter_proxy is not None:
                return parameter_proxy

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            buffer_proxy = lookup_proxy(self.root.named_buffers())
            if buffer_proxy is not None:
                return buffer_proxy

        return attr_val
981

982
983
984
985
    # Needed for PyTorch 1.13+
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        """Delegates to `_module_getattr` with the same arguments and returns its result."""
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

Michael Benayoun's avatar
Michael Benayoun committed
986
987
988
    def call_module(self, m, forward, args, kwargs):
        """Stores `forward` in `self.orig_forward`, then delegates to the parent implementation."""
        # `create_proxy` reads `self.orig_forward` back when computing the metadata of `call_module` nodes.
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)
989

Michael Benayoun's avatar
Michael Benayoun committed
990
991
    def proxy(self, node):
        """Wraps `node` in an `HFProxy` bound to this tracer."""
        return HFProxy(node, self)
992

993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
        dummy_inputs: Optional[Dict[str, Any]] = None,
        complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
    ) -> Graph:
        """
        Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
        `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
        the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
        `torch.nn.Module` instance to use as the root and add embedded constants to.

        Args:
            root (`torch.nn.Module` or  `Callable`):
                Either a `torch.nn.Module` or a function to be traced through. If root is not a
                [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
            concrete_args (`Dict[str, Any]`, *optional*):
                Concrete arguments that should not be treated as Proxies. The dictionary passed in is not modified.
            dummy_inputs (`Dict[str, Any]`, *optional*):
                The dummy inputs needed to handle data-dependent control-flow if `root` is not a
                [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
                [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
            complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
                If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
                `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

        Returns:
            `torch.fx.Graph`:
                A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

        Raises:
            ValueError: If `concrete_args` has to be completed and a parameter that is not in `dummy_inputs` has no
            default value.
            RuntimeError: If a dummy input has to be generated and `root` is neither an instance of
            `supported_archs` nor a deserialized traced module.
        """
        sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)

        # Work on a private copy: the dictionary may be completed below and must not leak back to the caller.
        concrete_args = {} if concrete_args is None else dict(concrete_args)

        if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
            for param in sig.parameters.values():
                if param.name in dummy_inputs:
                    continue
                if param.default is inspect.Parameter.empty:
                    raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
            concrete_args.update(
                {
                    p.name: p.default
                    for p in sig.parameters.values()
                    if (p.name not in dummy_inputs and p.name not in concrete_args)
                }
            )

        input_names = sig.parameters.keys() - concrete_args.keys()

        # Creating a random input shape to generate dummy inputs.
        batch_size = _generate_random_int()
        sequence_length = _generate_random_int()
        shape = [batch_size, sequence_length]

        if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
            num_choices = _generate_random_int(low=2, high=5)
            shape.insert(1, num_choices)

        inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
        for input_name in input_names:
            if input_name in inputs:
                continue
            # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
            # be able to use HFTracer._generate_dummy_input.
            if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
                ("_deserialize_graph_module", "_CodeOnlyModule")
            ):
                inputs.update(self._generate_dummy_input(root, input_name, shape))
            else:
                raise RuntimeError(
                    f"Could not generate input named {input_name} for because root is not a"
                    " transformers.PreTrainedModel."
                )

        concrete_metas = {
            input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
            for input_name, input_ in inputs.items()
        }
        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
                concrete_metas[f"**{param.name}"] = {}
        self.meta_args = concrete_metas
        self.patched_torch_methods = {
            target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        # Swap the torch functions for their proxy-aware wrappers for the duration of the trace.
        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        try:
            self.graph = super().trace(root, concrete_args=concrete_args)
        finally:
            # Always restore the original torch functions, even when tracing failed.
            for name, (_, orig) in self.patched_torch_methods.items():
                setattr(torch, name, orig)

        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in input_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    # Breadth-first walk over the node and its transitive users; erasing in reverse visiting order
                    # removes users before the nodes they depend on.
                    to_visit = collections.deque([node])
                    to_delete = collections.OrderedDict()
                    while to_visit:
                        n = to_visit.popleft()
                        to_delete[n] = None
                        to_visit.extend(n.users.keys())

                    for user in reversed(to_delete.keys()):
                        self.graph.erase_node(user)

            # TODO: solves GraphModule creation.
            # Without this, return type annotation "Tuple" is causing code execution failure.
            if node.op == "output":
                node.type = None

        return self.graph
1122

Michael Benayoun's avatar
Michael Benayoun committed
1123
1124
1125
1126
1127
1128
1129
    def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
        """
        Tells whether at least one instance attribute of `mod` is a `Proxy`, i.e. whether the module was
        instantiated with Proxies. Such a module cannot be a leaf module because its attributes are input-dependent.
        """
        for attribute_value in vars(mod).values():
            if isinstance(attribute_value, Proxy):
                return True
        return False

1130
    def _insert_module_as_submodule(self, mod: nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as submodule.

        Args:
            mod (`torch.nn.Module`):
                The module to register on `self.root`.

        Returns:
            `str`: The attribute name under which `mod` is available on `self.root` (`"<lowercased class
            name>_<index>"`), or an empty string when the module cannot be inserted.
        """
        # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
        # It is not possible to insert such modules, those should be traced through.
        if self._stateless_mod_instanciation_depends_on_proxies(mod):
            return ""

        mod_name = mod.__class__.__name__.lower()
        idx = 0
        path = f"{mod_name}_{idx}"
        while hasattr(self.root, path):
            # No need to add multiple instances of the same module.
            if getattr(self.root, path) is mod:
                return path
            # Move on to the next candidate name; the index is bumped *before* the name is rebuilt so that no
            # candidate is probed twice.
            idx += 1
            path = f"{mod_name}_{idx}"

        self.root.add_module(path, mod)
        return path

1154
    def path_of_module(self, mod: nn.Module) -> str:
        """
        Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
        a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
        string "foo.bar".

        When the parent lookup fails with a `NameError`, `allow_insert_stateless_mods` is set, and `mod` has neither
        parameters nor buffers, the result of `_insert_module_as_submodule` is returned instead.

        Args:
            mod (`torch.nn.Module`): The `Module` to retrieve the qualified name for.

        Raises:
            NameError: Re-raised from the parent lookup when the fallback above does not apply.
        """
        try:
            return super().path_of_module(mod)
        except NameError:
            # Peeking at the first element is enough to know whether the module holds any state.
            if (
                self.allow_insert_stateless_mods
                and next(mod.parameters(), None) is None
                and next(mod.buffers(), None) is None
            ):
                return self._insert_module_as_submodule(mod)
            raise
1170

Michael Benayoun's avatar
Michael Benayoun committed
1171
1172
1173
1174
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        A module whose instantiation depends on Proxies is never a leaf; for any other module the decision is left
        to the parent tracer.
        """
        if self._stateless_mod_instanciation_depends_on_proxies(m):
            return False
        return super().is_leaf_module(m, module_qualified_name)
1175

1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
    @compatibility(is_backward_compatible=True)
    def keys(self, obj: "Proxy") -> Any:
        """Called when a proxy object has its keys() method called.
        This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
        your custom tracer.
        """
        keys_proxy = HFAttribute(obj, "keys")()
        # For the `**kwargs` placeholder the metadata of the recorded call is returned instead of the proxy itself.
        return keys_proxy._metadata if obj.node.target == "**kwargs" else keys_proxy

1187

1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
def get_concrete_args(model: nn.Module, input_names: List[str]):
    """
    Computes the concrete arguments to use when tracing `model` with the inputs `input_names`.

    Args:
        model (`torch.nn.Module`):
            The model whose `forward` signature is inspected.
        input_names (`List[str]`):
            The names of the `forward` parameters that will be fed as inputs.

    Returns:
        `Dict[str, Any]`: A mapping from the name of every `forward` parameter that is not in `input_names` to its
        default value.

    Raises:
        ValueError: If some of `input_names` are not parameters of `model.forward`. The message lists the unknown
        names only.
    """
    sig = inspect.signature(model.forward)

    # Keep the order in which the names were requested, without duplicates.
    unknown_names = list(dict.fromkeys(name for name in input_names if name not in sig.parameters))
    if unknown_names:
        formatted_input_names = ", ".join(unknown_names)
        formatted_allowed_input_names = ", ".join(sig.parameters.keys())
        raise ValueError(
            f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
            f" {formatted_allowed_input_names}"
        )

    requested_names = set(input_names)
    return {p.name: p.default for p in sig.parameters.values() if p.name not in requested_names}


def check_if_model_is_supported(model: PreTrainedModel):
    """
    Checks that the class name of `model` is listed in `_SUPPORTED_MODELS`.

    Raises:
        NotImplementedError: If it is not. The message lists the supported model names.
    """
    if model.__class__.__name__ in _SUPPORTED_MODELS:
        return

    supported_model_names = ", ".join(_SUPPORTED_MODELS)
    raise NotImplementedError(
        f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
    )


1210
1211
1212
def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
    tracer_cls: Type[HFTracer] = HFTracer,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PreTrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly useful for debugging purposes.
        tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
            The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

        ```python
        from transformers.utils.fx import symbolic_trace

        traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
        ```
    """
    requested_names = model.dummy_inputs.keys() if input_names is None else input_names
    input_names = list(requested_names)

    # Every `forward` parameter that is not a requested input is frozen to its default value.
    concrete_args = get_concrete_args(model, input_names)

    if not disable_check:
        check_if_model_is_supported(model)

    # Tracing.
    graph = tracer_cls().trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, graph)

    traced.config = model.config
    # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
    # _generate_dummy_input, where the model class is needed.
    traced.class_for_deserialization = model.__class__
    traced.device = model.device

    return traced