# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins
import collections
import functools
import inspect
import math
import operator
import random
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from packaging import version
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx.proxy import ParameterProxy

from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils.versions import importlib_metadata


logger = logging.get_logger(__name__)


def _generate_supported_model_class_names(
    model_name: str,
    supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    task_mapping = {
        "default": MODEL_MAPPING_NAMES,
        "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
        "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
        "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    }

    if supported_tasks is None:
        supported_tasks = task_mapping.keys()
    if isinstance(supported_tasks, str):
        supported_tasks = [supported_tasks]

    model_class_names = []
    for task in supported_tasks:
        class_name = task_mapping[task].get(model_name, None)
        if class_name:
            model_class_names.append(class_name)

    return model_class_names
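
# For example, given the auto-mappings above one would expect:
#   _generate_supported_model_class_names("bert", supported_tasks="masked-lm") == ["BertForMaskedLM"]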


_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "albert",
    "bart",
    "bert",
    "blenderbot",
    "blenderbot-small",
    "bloom",
    "clip",
    "deberta",
    "deberta-v2",
    "distilbert",
    "electra",
    "gpt2",
    "gpt_neo",
    "gptj",
    "hubert",
    "layoutlm",
    "lxmert",
    "m2m_100",
    "marian",
    "mbart",
    "megatron-bert",
    "mobilebert",
    "mt5",
    "nezha",
    "opt",
    "pegasus",
    "plbart",
    "roberta",
    "speech_to_text",
    "speech_to_text_2",
    "swin",
    "t5",
    "trocr",
    "vit",
    "xglm",
    #    "xlnet",
]

_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    if isinstance(item, dict):
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
    else:
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))

_SPECIAL_SUPPORTED_MODELS = [
    "CLIPTextModel",
    "CLIPVisionModel",
    "GPT2DoubleHeadsModel",
    "Speech2Text2Decoder",
    "TrOCRDecoder",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # XLNetForQuestionAnswering,
]
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
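

# The functions below are shape-only reimplementations of torch operations: each
# mimics what the real op would return, but on the "meta" device, so the tracer
# can record output shapes and dtypes without doing any real computation. They
# are registered as overrides in _MANUAL_META_OVERRIDES further down.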
def torch_nn_embedding(self, input):
    return torch.empty(*input.shape, self.weight.shape[-1], device="meta")


def torch_nn_functional_embedding(
    input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
    return torch.empty(*input.shape, weight.shape[-1], device="meta")


def torch_nn_layernorm(self, input):
    return input


def torch_nn_groupnorm(self, input):
    return input


def torch_nn_linear(self, input):
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")


def torch_relu(x):
    return x


def torch_nn_relu(self, x):
    return x


def torch_nn_functional_relu(x, inplace=False):
    if inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x


def torch_where(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")


def torch_abs(input, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place abs for MetaTensor analysis")
    return input


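# Note: torch.arange produces ceil((end - start) / step) elements; the floor
# division below is the approximation used for the meta shape.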
def torch_arange(*args, **kwargs):
    n = len(args)
    step = 1
    if n == 1:
        start = 0
        end = args[0]
    elif n == 2:
        start, end = args
    else:
        start, end, step = args
    if isinstance(start, float):
        start = int(start)
    if isinstance(end, float):
        end = int(end)
    if isinstance(step, float):
        step = int(step)
    step = kwargs.get("step", step)
    dtype = kwargs.get("dtype")
    return torch.empty((end - start) // step, dtype=dtype, device="meta")


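# Shape rule illustrated: torch.cat of shapes (2, 3) and (2, 5) along dim=1 gives (2, 8).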
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
        dim = axis
    if dim < 0:
        dim = tensors[0].dim() + dim
    shapes = [t.shape for t in tensors]
    shape = list(shapes[0])
    concatenated_dim = sum(s[dim] for s in shapes)
    final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
    return torch.empty(final_shape, device="meta")


def torch_stack(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
        dim = axis
    if dim < 0:
        dim = tensors[0].dim() + 1 + dim
    shape = list(tensors[0].shape)
    shape.insert(dim, len(tensors))
    return torch.empty(shape, device="meta")


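# Broadcasting example for the rule below: shapes (4, 1, 3) + (5, 1) broadcast to (4, 5, 3).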
def torch_add(input, other, *, alpha=1, out=None):
    if not isinstance(input, torch.Tensor):
        return torch.empty_like(other, device="meta")
    if not isinstance(other, torch.Tensor):
        return torch.empty_like(input, device="meta")
    max_length = max(input.dim(), other.dim())
    # PyTorch broadcasting aligns trailing dimensions, so pad the shorter shape
    # with leading 1s before taking the elementwise maximum.
    input_shape = [1] * (max_length - input.dim()) + list(input.shape)
    other_shape = [1] * (max_length - other.dim()) + list(other.shape)
    shape = []
    for i in range(max_length):
        shape.append(max(input_shape[i], other_shape[i]))
    return torch.empty(shape, device="meta")


def torch_mul(input, other, *, out=None):
    return torch_add(input, other, out=out)


def torch_tensor_mul(self, other):
    return torch_mul(self, other)


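# Shape rules mirrored below, following torch.matmul semantics:
#   (n,) @ (n,) -> scalar, (a, n) @ (n, b) -> (a, b), (B, n, m) @ (m, p) -> (B, n, p).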
def torch_matmul(input, other, *, out=None):
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    if d1 == 1 and d2 == 1:
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        shape = (input.size(0),)
    else:
        max_length = max(input.dim(), other.dim())
        # Pad the shorter shape with -1 placeholders; the max() below then picks
        # the real (broadcasted) size from the other operand for those positions.
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = []
        for i in range(max_length):
            shape.append(max(shape1[i], shape2[i]))
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    if shape is None:
        return torch.tensor(0.0, device="meta")
    return torch.empty(*shape, device="meta")


def torch_bmm(input, mat2, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")


def torch_einsum(equation, *operands):
    # TODO: infer shape without performing the computation, this might be quite hard.
    concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
    return torch.einsum(equation, *concrete_operands).to("meta")


def torch_tensor_repeat(self, *sizes):
    shape = list(self.shape)
    for i, x in enumerate(sizes):
        shape[i] *= x
    return torch.empty(shape, device="meta")


def torch_index_select(input, dim, index, *, out=None):
    shape = list(input.shape)
    shape[dim] = len(index)
    return torch.empty(*shape, device="meta")


def torch_tensor_index_select(self, dim, index):
    return torch_index_select(self, dim, index)


def torch_roll(input, shifts, dims=None):
    return input


def torch_flip(input, dims):
    return input


def torch_tensor_flip(self, dims):
    return self


def torch_nn_conv1d(self, input):
    l_in = input.shape[-1]
    shape = None
    padding = self.padding
    if padding == "valid":
        padding = (0, 0)
    if padding == "same":
        shape = list(input.shape)
    if shape is None:
        shape = list(input.shape)
        l_out = math.floor(
            (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
        )
        shape[-1] = l_out
    shape[-2] = self.out_channels
    return torch.empty(shape, device="meta")


def torch_nn_conv2d(self, input):
    h_in, w_in = input.shape[-2:]
    shape = None
    padding = self.padding
    if padding == "valid":
        padding = (0, 0)
    if padding == "same":
        shape = list(input.shape)
    if shape is None:
        shape = list(input.shape)
        h_out = math.floor(
            (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
        )
        w_out = math.floor(
            (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
        )
        shape[-2:] = [h_out, w_out]
    shape[-3] = self.out_channels
    return torch.empty(shape, device="meta")


def torch_squeeze(input, dim=None):
    shape = list(input.shape)
    if dim is not None:
        if dim < 0:
            dim = input.dim() + dim
        if shape[dim] == 1:
            shape.pop(dim)
    else:
        new_shape = []
        for dim_value in shape:
            if dim_value == 1:
                continue
            new_shape.append(dim_value)
        shape = new_shape
    return torch.empty(shape, device="meta")


def torch_tensor_squeeze(self, dim=None):
    return torch_squeeze(self, dim)


def torch_unsqueeze(input, dim):
    shape = list(input.shape)
    if dim < 0:
        dim = input.dim() + 1 + dim
    shape.insert(dim, 1)
    return torch.empty(shape, device="meta")


def torch_tensor_unsqueeze(self, dim):
    return torch_unsqueeze(self, dim)


def torch_unique_consecutive(input, **kwargs):
    output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
    if isinstance(output, torch.Tensor):
        return output.to("meta")
    else:
        return tuple(map(lambda x: x.to("meta"), output))


def torch_nn_functional_one_hot(tensor, num_classes=-1):
    if num_classes < 0:
        raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
    shape = list(tensor.shape) + [num_classes]
    return torch.empty(shape, device="meta")


def torch_nn_mseloss(self, input, target):
    if self.reduction == "none":
        shape = target.shape
    else:
        shape = (1,)
    return torch.empty(shape, device="meta")


def torch_nn_crossentropyloss(self, input, target):
    if self.reduction == "none":
        shape = target.shape
    else:
        shape = (1,)
    return torch.empty(shape, device="meta")


def torch_nn_bcewithlogitsloss(self, input, target):
    if self.reduction == "none":
        shape = target.shape
    else:
        shape = (1,)
    return torch.empty(shape, device="meta")


def operator_getitem(a, b):
    def to_concrete(t):
        if isinstance(t, torch.Tensor):
            concrete = torch.ones_like(t, device="cpu")
            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
                concrete = concrete.to(torch.int64)
            return concrete
        return t

    if isinstance(a, torch.Tensor):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
    return operator.getitem(a, b)


_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    torch.nn.Embedding: torch_nn_embedding,
    torch.nn.functional.embedding: torch_nn_functional_embedding,
    torch.nn.LayerNorm: torch_nn_layernorm,
    torch.nn.GroupNorm: torch_nn_groupnorm,
    torch.nn.Linear: torch_nn_linear,
    torch.relu: torch_relu,
    torch.nn.functional.relu: torch_nn_functional_relu,
    torch.nn.ReLU: torch_nn_relu,
    torch.where: torch_where,
    torch.abs: torch_abs,
    torch.arange: torch_arange,
    torch.cat: torch_cat,
    torch.stack: torch_stack,
    torch.add: torch_add,
    torch.mul: torch_mul,
    torch.Tensor.mul: torch_tensor_mul,
    torch.matmul: torch_matmul,
    torch.bmm: torch_bmm,
    torch.einsum: torch_einsum,
    torch.Tensor.repeat: torch_tensor_repeat,
    torch.roll: torch_roll,
    torch.flip: torch_flip,
    torch.Tensor.flip: torch_tensor_flip,
    torch.index_select: torch_index_select,
    torch.Tensor.index_select: torch_tensor_index_select,
    torch.nn.Conv1d: torch_nn_conv1d,
    torch.nn.Conv2d: torch_nn_conv2d,
    torch.squeeze: torch_squeeze,
    torch.Tensor.squeeze: torch_tensor_squeeze,
    torch.unsqueeze: torch_unsqueeze,
    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
    torch.unique_consecutive: torch_unique_consecutive,
    torch.nn.functional.one_hot: torch_nn_functional_one_hot,
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
    operator.getitem: operator_getitem,
}
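
# A minimal sketch of registering an extra override, assuming one were needed
# (torch.clamp is elementwise, so its meta output could simply reuse the input):
#
#     def torch_clamp(input, min=None, max=None, *, out=None):
#         return input
#
#     _MANUAL_META_OVERRIDES[torch.clamp] = torch_clamp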


class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.
    """

    def install_metadata(self, metadata):
        self._metadata = metadata

    @property
    def shape(self):
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def dtype(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata.dtype
        return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata
        return super().__bool__()

    def __getattr__(self, k):
        if k == "_metadata":
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return HFAttribute(self, k)

    def __setitem__(self, indices, values):
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return key in self._metadata
        return super().__contains__(key)


class HFAttribute(HFProxy):
    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)


class MetaDeviceAttribute(HFAttribute):
    pass


def _proxies_to_metas(v):
    """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
    if isinstance(v, MetaDeviceAttribute):
        return "meta"
    if isinstance(v, torch.fx.Proxy):
        if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")):
            raise RuntimeError(f"No metadata was found for {v}")
        return v._metadata
    return v


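# Wraps a torch constructor (torch.ones, torch.arange, ...) so that, if any argument
# is a Proxy, the call is recorded in the graph instead of being executed eagerly.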
def _gen_constructor_wrapper(target):
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, Proxy):
                nonlocal proxy
                proxy = v

        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target


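# Draws a random dummy dimension while skipping forbidden_values, so callers can
# ensure that distinct axes receive distinct sizes.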
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
    if forbidden_values is None:
        forbidden_values = []
    value = random.randint(low, high)
    while value in forbidden_values:
        value = random.randint(low, high)
    return value


class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library. To do that, it uses HFProxy instead of the
    regular PyTorch torch.fx.Proxy.
    """

    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = True
    allow_insert_stateless_mods: bool = True
    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]

    def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
        super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

        if not is_torch_fx_available():
            torch_version = version.parse(importlib_metadata.version("torch"))
            raise ImportError(
                f"Found an incompatible version of torch. Found version {torch_version}, but only version "
                f"{TORCH_FX_REQUIRED_VERSION} is supported."
            )

    def _generate_dummy_input(
        self, model: PreTrainedModel, input_name: str, shape: List[int]
    ) -> Dict[str, torch.Tensor]:
        """Generates dummy input for model inference recording."""
        # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
        # from pickle, or from the "__class__" attribute in the general case.
        model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
        device = model.device
        inputs_dict = {}

        if input_name in ["labels", "start_positions", "end_positions"]:
            batch_size = shape[0]
            if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
                inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in [
                *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
                "XLNetForQuestionAnswering",
            ]:
                inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
                inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
                if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
                    raise ValueError(
                        "Could not retrieve the problem type for the sequence classification task, please set "
                        'model.config.problem_type to one of the following values: "regression", '
                        '"single_label_classification", or "multi_label_classification".'
                    )

                if model.config.problem_type == "regression":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                elif model.config.problem_type == "single_label_classification":
                    labels_shape = (batch_size,)
                    labels_dtype = torch.long
                elif model.config.problem_type == "multi_label_classification":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                else:
                    raise ValueError(
                        'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
                        f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
                    )
                inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)

            elif model_class_name in [
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
            ]:
                inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in [
                *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
                "GPT2DoubleHeadsModel",
            ]:
                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
            else:
                raise NotImplementedError(f"{model_class_name} not supported yet.")
        elif "pixel_values" in input_name:
            batch_size = shape[0]
            image_size = getattr(model.config, "image_size", None)
            if image_size is None:
                if hasattr(model.config, "vision_config"):
                    image_size = model.config.vision_config.image_size
                elif hasattr(model.config, "encoder"):
                    image_size = model.config.encoder.image_size
                else:
                    raise AttributeError('Could not find the "image_size" field in the model config')

            # If num_channels is not in the config, fall back to an arbitrary default.
            num_channels = getattr(model.config, "num_channels", 3)
            if not isinstance(image_size, collections.abc.Iterable):
                image_size = (image_size, image_size)
            height, width = image_size
            inputs_dict[input_name] = torch.zeros(
                batch_size, num_channels, height, width, dtype=torch.float32, device=device
            )
        elif "bbox" in input_name:
            inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
        elif "input_features" in input_name:
            inputs_dict[input_name] = torch.zeros(
                *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
            )
        elif "visual_feats" in input_name:
            inputs_dict[input_name] = torch.zeros(
                shape + [model.config.visual_feat_dim], dtype=torch.float, device=device
            )
        elif "visual_pos" in input_name:
            inputs_dict[input_name] = torch.zeros(
                shape + [model.config.visual_pos_dim], dtype=torch.float, device=device
            )
        elif "inputs" in input_name:
            inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
        elif "input_values" in input_name:
            batch_size, _ = shape
            # Generating big sequence length for audio inputs.
            seq_length = _generate_random_int(low=10000, high=20000)
            inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
        elif "mask" in input_name or "ids" in input_name:
            inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
        else:
            shape_with_hidden_size = shape + [model.config.hidden_size]
            inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)

        return inputs_dict

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        if kind == "placeholder" and target in self.meta_args:
            rv.install_metadata(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if "device" in kwargs:
                kwargs["device"] = "meta"

        try:
            args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

            if kind == "call_function":
                meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
                if isinstance(meta_out, torch.Tensor):
                    meta_out = meta_out.to(device="meta")
            elif kind == "call_method":
                method = getattr(args_metas[0].__class__, target)
                meta_target = _MANUAL_META_OVERRIDES.get(method, method)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == "call_module":
                if not hasattr(self, "orig_forward"):
                    raise AttributeError(f"{self} does not have an attribute called orig_forward")
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if mod_type in _MANUAL_META_OVERRIDES:
                        meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False
            elif kind == "get_attr":
                self._disable_module_getattr = True
                try:
                    attr_itr = self.root
                    atoms = target.split(".")
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    if isinstance(attr_itr, torch.Tensor):
                        meta_out = attr_itr.to(device="meta")
                    else:
                        meta_out = attr_itr
                finally:
                    self._disable_module_getattr = False
            else:
                return rv

            if not isinstance(rv, Proxy):
                raise ValueError("Don't support composite output yet")
            rv.install_metadata(meta_out)
        except Exception as e:
            warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

        return rv

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, "_disable_module_getattr", False):
            return attr_val
        else:
            # return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
                for n, p in collection_to_search:
                    if attr_val is p:
                        if n not in parameter_proxy_cache:
                            kwargs = {}
                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                                kwargs["proxy_factory_fn"] = (
                                    None
                                    if not self.param_shapes_constant
                                    else lambda node: ParameterProxy(self, node, n, attr_val)
                                )
                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                            parameter_proxy_cache[n] = val_proxy
                        return parameter_proxy_cache[n]
                return None

            if isinstance(attr_val, torch.nn.Parameter):
                maybe_parameter_proxy = maybe_get_proxy_for_attr(
                    attr_val, self.root.named_parameters(), parameter_proxy_cache
                )
                if maybe_parameter_proxy is not None:
                    return maybe_parameter_proxy

            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
                maybe_buffer_proxy = maybe_get_proxy_for_attr(
                    attr_val, self.root.named_buffers(), parameter_proxy_cache
                )
                if maybe_buffer_proxy is not None:
                    return maybe_buffer_proxy

            return attr_val

    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)

    def proxy(self, node):
        return HFProxy(node, self)

    def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
        if concrete_args is None:
            concrete_args = {}

        sig = inspect.signature(root.forward)
        input_names = sig.parameters.keys() - concrete_args.keys()

        # Creating a random input shape to generate dummy inputs.
        batch_size = _generate_random_int()
        sequence_length = _generate_random_int()
        shape = [batch_size, sequence_length]

        if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
            num_choices = _generate_random_int(low=2, high=5)
            shape.insert(1, num_choices)

        inputs = {}
        for input_name in input_names:
            inputs.update(self._generate_dummy_input(root, input_name, shape))

        concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
                concrete_metas[f"**{param.name}"] = {}
        self.meta_args = concrete_metas
        self.patched_torch_methods = {
            target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        try:
            self.graph = super().trace(root, concrete_args=concrete_args)
        finally:
            for name, (_, orig) in self.patched_torch_methods.items():
                setattr(torch, name, orig)

        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in input_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    to_visit = [node]
                    to_delete = collections.OrderedDict()
                    while to_visit:
                        n = to_visit.pop(0)
                        to_delete[n] = None
                        to_visit += list(n.users.keys())

                    for user in reversed(to_delete.keys()):
                        self.graph.erase_node(user)

            # TODO: properly handle GraphModule creation.
            # Without this, the "Tuple" return type annotation makes code execution fail.
            if node.op == "output":
                node.type = None

        return self.graph

    def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
        """
        Whether the module was instantiated with Proxies. If that is the case, such a module cannot be a leaf module
        because its attributes are input-dependent.
        """
        return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())

    def _insert_module_as_submodule(self, mod: nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as a submodule.
        """
        # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
        # It is not possible to insert such modules, those should be traced through.
        if self._stateless_mod_instanciation_depends_on_proxies(mod):
            return ""
        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        already_inserted = False
        while hasattr(self.root, path):
            if getattr(self.root, path) is mod:
                already_inserted = True
                break
            idx += 1
            path = f"{mod_name}_{idx}"

        # No need to add multiple instances of the same module.
        if not already_inserted:
            self.root.add_module(path, mod)
        return path

    def path_of_module(self, mod: nn.Module) -> str:
        """
        Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
        a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
        string "foo.bar".

        Args:
            mod (nn.Module): The `Module` to retrieve the qualified name for.
        """
        try:
            return super().path_of_module(mod)
        except NameError as e:
            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
                path = self._insert_module_as_submodule(mod)
                return path
            raise e

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module(
            m, module_qualified_name
        )


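# Illustrative example with a hypothetical signature: for forward(self, input_ids=None,
# attention_mask=None), get_concrete_args(model, ["input_ids"]) == {"attention_mask": None}.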
def get_concrete_args(model: nn.Module, input_names: List[str]):
    sig = inspect.signature(model.forward)

    if not (set(input_names) <= set(sig.parameters.keys())):
        formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
        formatted_allowed_input_names = ", ".join(sig.parameters.keys())
        raise ValueError(
            f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
            f" {formatted_allowed_input_names}"
        )

    return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}


def check_if_model_is_supported(model: PreTrainedModel):
    if model.__class__.__name__ not in _SUPPORTED_MODELS:
        supported_model_names = ", ".join(_SUPPORTED_MODELS)
        raise NotImplementedError(
            f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
        )


def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PreTrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):
            The names of the inputs of the traced model. If unset, the keys of `model.dummy_inputs` are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model; this is mostly useful for debugging purposes.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

        ```python
        from transformers.utils.fx import symbolic_trace

        traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
        ```
    """
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    input_names = list(input_names)
    concrete_args = get_concrete_args(model, input_names)

    if not disable_check:
        check_if_model_is_supported(model)

    # Tracing.
    tracer = HFTracer()
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

    traced.config = model.config
    # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
    # _generate_dummy_input, where the model class is needed.
    traced.class_for_deserialization = model.__class__
    traced.device = model.device

    return traced