# torch_overrides.py -- function lists controlling AMP casting behavior for torch.*
import torch

from .. import utils

# The module whose attributes the names below refer to.
MODULE = torch

# Functions that are safe (and fast) to run in FP16 on tensor-core hardware.
FP16_FUNCS = [
    # Low level functions wrapped by torch.nn layers.
    # The wrapper layers contain the weights which are then passed in as a parameter
    # to these functions.
    'conv1d',
    'conv2d',
    'conv3d',
    'conv_transpose1d',
    'conv_transpose2d',
    'conv_transpose3d',
    'conv_tbc',
    'prelu',

    # BLAS
    'addmm',
    'addmv',
    'addr',
    'matmul',
    'mm',
    'mv',
]

# Functions that must run in FP32 for numerical stability (transcendentals,
# reductions, etc.); AMP casts their inputs up before calling them.
FP32_FUNCS = [
    # Pointwise
    'acos',
    'asin',
    'cosh',
    'erfinv',
    'exp',
    'expm1',
    'log',
    'log10',
    'log2',
    'reciprocal',
    'rsqrt',
    'sinh',
    'tan',

    # Other math
    'pow',

    # Reduction
    'cumprod',
    'cumsum',
    'dist',
    # 'mean' is deliberately absent here; it is appended conditionally
    # later in this file for torch versions older than 1.1.
    'norm',
    'prod',
    'std',
    'sum',
    'var',

    # Misc
    'renorm'
]

# Parse the running torch version so we can special-case old releases.
version_strings = torch.__version__.split('.')
version_major = version_strings[0]
version_minor = version_strings[1]
# NOTE(review): float("major.minor") collapses "1.10" to 1.1, which still
# compares correctly against the 1.1 threshold below, but a minor component
# with a suffix (e.g. a local "0rc1") would make float() raise -- confirm
# upstream torch versions never hit that here.
version_num = float(version_major + "." + version_minor)
# Before torch 1.1, mean must be blacklisted (kept on the FP32 list):
# it lacked a stable FP16 path on those releases.
if version_num < 1.1:
    FP32_FUNCS.append('mean')

# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
_bmms = ['addbmm',
         'baddbmm',
         'bmm']
if utils.get_cuda_version() >= (9, 1, 0):
    FP16_FUNCS.extend(_bmms)
else:
    FP32_FUNCS.extend(_bmms)

# Multi-tensor fns that may need type promotion: AMP promotes all tensor
# arguments to the widest dtype among them before dispatching.
CASTS = [
    # Multi-tensor math
    'addcdiv',
    'addcmul',
    'atan2',
    'cross',
    'bilinear',
    'dot',

    # Element-wise _or_ tensor-wise math
    'add',
    'div',
    'mul',

    # Comparison
    'eq',
    'equal',
    'ge',
    'gt',
    'le',
    'lt',
    'ne'
]

# Functions that take sequence arguments. We need to inspect the whole
# sequence and cast to the widest type.
SEQUENCE_CASTS = [
    'cat',
    'stack'
]