functional_overrides.py 2.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

# TODO: think about the following two. They do weird things.
# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
# - torch.nn.utils.weight_norm

# Notes:
# F.instance_norm uses batch_norm internally. Which correctly handles
#   fp16 in/out with fp32 weights. So we shouldn't do anything for
#   either of these.
# F.normalize calls `input.norm()` internally, so it's redundant, but
#   kept here in case impl. changes.
# F.cosine_similarity is same: calls `x.norm()` internally.

import torch.nn.functional

MODULE = torch.nn.functional

FP16_FUNCS = [
    'conv1d',
    'conv2d',
    'conv3d',
    'conv_transpose1d',
    'conv_transpose2d',
    'conv_transpose3d',
    'conv_tbc', # Undocumented / maybe new?
    'linear',
]

FP32_FUNCS = [
30
31
32
33

    # Interpolation/Upsampling
    'interpolate',

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
    # Pointwise
    'softplus',
    'softmin',
    'log_softmax',
    'softmax',

    # Normalization
    'layer_norm',
    'group_norm',
    'local_response_norm',
    'normalize',
    'cosine_similarity',

    # Loss functions
    # TODO: which of these can be fp16?
    'poisson_nll_loss',
    'cosine_embedding_loss',
    'cross_entropy',
    'hinge_embedding_loss',
    'kl_div',
    'l1_loss',
    'mse_loss',
    'margin_ranking_loss',
    'multilabel_margin_loss',
    'multilabel_soft_margin_loss',
    'multi_margin_loss',
    'nll_loss',
    'binary_cross_entropy_with_logits',
    'smooth_l1_loss',
    'soft_margin_loss',
    'triplet_margin_loss'
]
Carl Case's avatar
Carl Case committed
66
67
68
69
70
71
72
73
74
75
76
77

BANNED_FUNCS = [
    ('binary_cross_entropy',
     ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
      "It requires that the output of the previous function be already a FloatTensor. \n\n"
      "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
      "    torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
      "that is compatible with amp.\nAnother option is to add\n"
      "    amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
      "If you _really_ know what you are doing, you can disable this warning by passing "
      "allow_banned=True to `amp.init()`."))
]