from . import compat

import functools
import itertools

import torch

def is_cuda_enabled():
    return torch.version.cuda is not None

def get_cuda_version():
    return tuple(int(x) for x in torch.version.cuda.split('.'))
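
# Illustrative sketch (not executed as part of this module): torch.version.cuda
# is a version string such as '11.8', so get_cuda_version() yields an integer
# tuple that compares naturally:
#
#     if is_cuda_enabled() and get_cuda_version() >= (11, 0):
#         pass  # take a CUDA 11+ code path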

def is_fp_tensor(x):
    if is_nested(x):
        # Fast-fail version of all(is_fp_tensor)
        for y in x:
            if not is_fp_tensor(y):
                return False
        return True
    return compat.is_tensor_like(x) and compat.is_floating_point(x)

def is_nested(x):
    return isinstance(x, tuple) or isinstance(x, list)

def should_cache(x):
    if is_nested(x):
        # Fast-fail version of all(should_cache)
        for y in x:
            if not should_cache(y):
                return False
        return True
    return isinstance(x, torch.nn.parameter.Parameter) and \
        type_string(x) == 'FloatTensor'

def collect_fp_tensor_types(args, kwargs):
    def collect_types(x, types):
        if is_nested(x):
            for y in x:
                collect_types(y, types)
        else:
            types.add(type_string(x))

    all_args = itertools.chain(args, kwargs.values())
    types = set()
    for x in all_args:
        if is_fp_tensor(x):
            collect_types(x, types)
    return types

def type_string(x):
    return x.type().split('.')[-1]
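
# Illustrative sketch (not executed here; assumes a CUDA device is available):
# collect_fp_tensor_types reduces each floating-point argument to the short
# type string produced by type_string, e.g.
#
#     args = (torch.randn(2, device='cuda').half(), torch.randn(2))
#     collect_fp_tensor_types(args, {})   # -> {'HalfTensor', 'FloatTensor'}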

def maybe_half(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_half(y) for y in x])

    if not x.is_cuda or type_string(x) == 'HalfTensor':
        return x
    else:
        if verbose:
            print('Float->Half ({})'.format(name))
        return x.half()
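
# Illustrative sketch (not executed here; assumes a CUDA device is available):
# maybe_half only downcasts CUDA tensors that are not already half; CPU tensors
# are returned unchanged, and nested lists/tuples are cast element-wise:
#
#     maybe_half(torch.randn(4, device='cuda'))   # -> torch.cuda.HalfTensor
#     maybe_half(torch.randn(4))                  # CPU tensor, returned as-is
#     maybe_half([a, b])                          # -> [maybe_half(a), maybe_half(b)]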

def maybe_bfloat16(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_bfloat16(y) for y in x])

    if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
        return x
    else:
        if verbose:
            print('Float->BFloat16 ({})'.format(name))
        return x.bfloat16()

def maybe_float(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_float(y) for y in x])

    if not x.is_cuda or type_string(x) == 'FloatTensor':
        return x
    else:
        if verbose:
            print('Half->Float ({})'.format(name))
        return x.float()

# NB: returns a casted copy of `args`, mutates `kwargs` in-place
def casted_args(cast_fn, args, kwargs):
    new_args = []
    for x in args:
        if is_fp_tensor(x):
            new_args.append(cast_fn(x))
        else:
            new_args.append(x)
    for k in kwargs:
        val = kwargs[k]
        if is_fp_tensor(val):
            kwargs[k] = cast_fn(val)
    return new_args
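
# Illustrative sketch (hypothetical wrapper, not part of this module): the amp
# patching layer pairs casted_args with one of the maybe_* cast functions to
# rewrite a call's floating-point arguments before dispatching to the original
# function.  Note that positional args come back as a new list while kwargs are
# mutated in place:
#
#     def wrapper(orig_fn, *args, **kwargs):
#         new_args = casted_args(maybe_half, args, kwargs)
#         return orig_fn(*new_args, **kwargs)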

def cached_cast(cast_fn, x, cache):
    if is_nested(x):
        return type(x)([cached_cast(cast_fn, y, cache) for y in x])
    if x in cache:
        cached_x = cache[x]
        next_functions_available = False
        if x.requires_grad and cached_x.requires_grad:
            if len(cached_x.grad_fn.next_functions) > 1:
                next_functions_available = True
            # Make sure x is actually cached_x's autograd parent.
            if next_functions_available and cached_x.grad_fn.next_functions[1][0].variable is not x:
                raise RuntimeError("x and cache[x] both require grad, but x is not "
                                   "cache[x]'s parent.  This is likely an error.")
        # During eval, it's possible to end up caching casted weights with
        # requires_grad=False.  On the next training iter, if cached_x is found
        # and reused from the cache, it will not actually have x as its parent.
        # Therefore, we choose to invalidate the cache (and force refreshing the cast)
        # if x.requires_grad and cached_x.requires_grad do not match.
        #
        # During eval (i.e. running under with torch.no_grad()) the invalidation
        # check would cause the cached value to be dropped every time, because
        # cached_x would always be created with requires_grad=False, while x would
        # still have requires_grad=True.  This would render the cache effectively
        # useless during eval.  Therefore, if we are running under the no_grad()
        # context manager (torch.is_grad_enabled=False) we elide the invalidation
        # check, and use the cached value even though its requires_grad flag doesn't
        # match.  During eval, we don't care that there's no autograd-graph
        # connection between x and cached_x.
        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
            del cache[x]
        elif x.requires_grad and cached_x.requires_grad and not next_functions_available:
            del cache[x]
        else:
            return cached_x

    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x
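
# Illustrative sketch (not executed here; assumes a CUDA device is available):
# for parameters that pass should_cache, callers reuse one cache dict so a
# weight is converted once and then served from the cache while the
# requires_grad/graph checks above still pass:
#
#     cache = {}
#     w = torch.nn.Parameter(torch.randn(4, device='cuda'))
#     h = cached_cast(maybe_half, w, cache)    # cast performed and stored in cache
#     h = cached_cast(maybe_half, w, cache)    # served from the cache when the
#                                              # invalidation checks allow it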

def verbosify(cast_fn, fn_name, verbose):
    if verbose:
        return functools.partial(cast_fn, name=fn_name, verbose=verbose)
    else:
        return cast_fn

def as_inplace(fns):
    for x in fns:
        yield x + '_'

def has_func(mod, fn):
    if isinstance(mod, dict):
        return fn in mod
    else:
        return hasattr(mod, fn)

def get_func(mod, fn):
    if isinstance(mod, dict):
        return mod[fn]
    else:
        return getattr(mod, fn)

def set_func(mod, fn, new_fn):
    if isinstance(mod, dict):
        mod[fn] = new_fn
    else:
        setattr(mod, fn, new_fn)

def set_func_save(handle, mod, fn, new_fn):
    cur_fn = get_func(mod, fn)
    handle._save_func(mod, fn, cur_fn)
    set_func(mod, fn, new_fn)
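
# Illustrative sketch (hypothetical make_wrapper helper; `handle` is the amp
# handle whose _save_func records the original for later restoration):
# has_func/get_func/set_func treat a module, a class, or a plain dict of
# functions uniformly, which is how a torch function gets swapped for a
# casting wrapper:
#
#     if has_func(torch, 'bmm'):
#         orig = get_func(torch, 'bmm')
#         set_func_save(handle, torch, 'bmm', make_wrapper(orig))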

# A couple problems get solved here:
# - The flat_weight buffer is disconnected from autograd graph,
#   so the fp16 weights need to be derived from the input weights
#   to this forward call, not the flat buffer.
# - The ordering of weights in the flat buffer is...idiosyncratic.
# The first problem is solved with a combination of set_ (to set up
# the correct storage) and copy_ (so the fp16 weight derives from the
# fp32 one in autograd).
# The second is solved by doing pointer arithmetic on the fp32 weights
# to derive the correct offset.
#
# TODO: maybe this should actually use
# `torch._cudnn_rnn_flatten_weight`? But then I need to call
# on first iter and cache the right offsets. Ugh.
def synthesize_flattened_rnn_weights(fp32_weights,
                                     fp16_flat_tensor,
                                     rnn_fn='',
                                     verbose=False):
    fp16_weights = []
    fp32_base_ptr = fp32_weights[0][0].data_ptr()
    for layer_weights in fp32_weights:
        fp16_layer_weights = []
        for w_fp32 in layer_weights:
            w_fp16 = w_fp32.new().half()
            offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
            w_fp16.set_(fp16_flat_tensor.storage(),
                        offset,
                        w_fp32.shape)
            w_fp16.copy_(w_fp32)
            if verbose:
                print('Float->Half ({})'.format(rnn_fn))
            fp16_layer_weights.append(w_fp16)
        fp16_weights.append(fp16_layer_weights)
    return fp16_weights
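
# Illustrative worked example (numbers made up): the offset computation above
# is plain pointer arithmetic.  If the first fp32 weight starts at data_ptr()
# 0x1000 and another 4-byte-per-element fp32 weight starts at 0x1200, that
# weight begins (0x1200 - 0x1000) // 4 = 128 elements into the flat buffer, so
# its fp16 shadow is set_() on fp16_flat_tensor's storage at element offset 128
# with the same shape, then copy_()'d so autograd connects it to the fp32 weight.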

def _str_from_dtype(dtype=torch.float16):
    type_to_str = {torch.float16 : 'Half',
                   torch.bfloat16 : 'BFloat16'}
    return type_to_str[dtype]

# Roughly same as above, just the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
                                         fp16_flat_tensor,
                                         rnn_fn='',
                                         dtype=torch.float16,
                                         verbose=False):
    fp16_weights = []
    fp32_base_ptr = fp32_weights[0].data_ptr()
    for w_fp32 in fp32_weights:
        w_fp16 = w_fp32.new().to(dtype=dtype)
        offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
        w_fp16.set_(fp16_flat_tensor.storage(),
                    offset,
                    w_fp32.shape)
        w_fp16.copy_(w_fp32)
        if verbose:
            print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
        fp16_weights.append(w_fp16)
    return fp16_weights