from . import compat

import functools
import itertools

import torch

def is_fp_tensor(x):
    if is_nested(x):
        # Fast-fail version of all(is_fp_tensor)
        for y in x:
            if not is_fp_tensor(y):
                return False
        return True
    return compat.is_tensor_like(x) and compat.is_floating_point(x)

def is_nested(x):
    return isinstance(x, tuple) or isinstance(x, list)

def should_cache(x):
    if is_nested(x):
        # Fast-fail version of all(should_cache)
        for y in x:
            if not should_cache(y):
                return False
        return True
    return isinstance(x, torch.nn.parameter.Parameter) and \
        type_string(x) == 'FloatTensor'

def collect_fp_tensor_types(args, kwargs):
    def collect_types(x, types):
        if is_nested(x):
            for y in x:
                collect_types(y, types)
        else:
            types.add(type_string(x))

    all_args = itertools.chain(args, kwargs.values())
    types = set()
    for x in all_args:
        if is_fp_tensor(x):
            collect_types(x, types)
    return types
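
# A minimal illustration of the result (hypothetical tensors, not part of
# this module; plain CPU tensors suffice since only the dtype matters here):
#
#   a = torch.randn(2)          # 'FloatTensor'
#   b = torch.randn(2).half()   # 'HalfTensor'
#   collect_fp_tensor_types((a, [b]), {'c': torch.randn(2)})
#   # -> {'FloatTensor', 'HalfTensor'}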

def type_string(x):
    return x.type().split('.')[-1]

def maybe_half(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_half(y) for y in x])

    if not x.is_cuda or type_string(x) == 'HalfTensor':
        return x
    else:
        if verbose:
            print('Float->Half ({})'.format(name))
        return x.half()

def maybe_float(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_float(y) for y in x])

    if not x.is_cuda or type_string(x) == 'FloatTensor':
        return x
    else:
        if verbose:
            print('Half->Float ({})'.format(name))
        return x.float()

# NB: returns casted `args`, mutates `kwargs` in-place
def casted_args(cast_fn, args, kwargs):
    new_args = []
    for x in args:
        if is_fp_tensor(x):
            new_args.append(cast_fn(x))
        else:
            new_args.append(x)
    for k in kwargs:
        val = kwargs[k]
        if is_fp_tensor(val):
            kwargs[k] = cast_fn(val)
    return new_args
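
# A minimal usage sketch (illustrative only; the tensors are hypothetical and
# a CUDA device is assumed, since maybe_half only casts CUDA tensors):
#
#   args = [torch.randn(2, 4).cuda()]
#   kwargs = {'other': torch.randn(2, 4).cuda()}
#   new_args = casted_args(maybe_half, args, kwargs)
#   # new_args holds half-precision casts; kwargs was casted in place.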

def cached_cast(cast_fn, x, cache):
    # print("Calling cached_cast")
    if is_nested(x):
        return type(x)([cached_cast(cast_fn, y, cache) for y in x])
    if x in cache:
        cached_x = cache[x]
        if x.requires_grad and cached_x.requires_grad:
            # Check to make sure x is actually cached_x's autograd parent.
            if cached_x.grad_fn.next_functions[1][0].variable is not x:
                raise RuntimeError("x and cache[x] both require grad, but x is not "
                                   "cache[x]'s parent.  This is likely an error.")
        # If the requires_grad flags have diverged, the cached cast is stale;
        # drop it so a fresh cast (with the right autograd linkage) is made.
        if x.requires_grad != cached_x.requires_grad:
            del cache[x]
        else:
            return cached_x

    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x
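
# Rough sketch of the intended use (hypothetical `cache` and `param`, not part
# of this module's API; assumes a CUDA device):
#
#   cache = {}
#   param = torch.nn.Parameter(torch.randn(3, 3).cuda())
#   h1 = cached_cast(maybe_half, param, cache)
#   h2 = cached_cast(maybe_half, param, cache)  # reuses cache[param]
#   # h1 is h2, and the grad_fn check above verifies the cached half tensor
#   # was actually produced from `param`, so gradients still reach the fp32
#   # master weight.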

def verbosify(cast_fn, fn_name, verbose):
    if verbose:
        return functools.partial(cast_fn, name=fn_name, verbose=verbose)
    else:
        return cast_fn

def as_inplace(fns):
    for x in fns:
        yield x + '_'

def has_func(mod, fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        return fn in mod.function_classes
    elif isinstance(mod, dict):
        return fn in mod
    else:
        return hasattr(mod, fn)

def get_func(mod, fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        return mod.function_classes[fn]
    elif isinstance(mod, dict):
        return mod[fn]
    else:
        return getattr(mod, fn)

def set_func(mod, fn, new_fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        mod.function_classes[fn] = new_fn
    elif isinstance(mod, dict):
        mod[fn] = new_fn
    else:
        setattr(mod, fn, new_fn)

def set_func_save(handle, mod, fn, new_fn):
    cur_fn = get_func(mod, fn)
    handle._save_func(mod, fn, cur_fn)
    set_func(mod, fn, new_fn)
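
# Sketch of the patching flow these helpers support. The `_Handle` class and
# the addmm wrapper are assumptions for illustration; the real handle comes
# from amp and only needs to provide `_save_func`:
#
#   class _Handle(object):
#       def __init__(self):
#           self._saved = []
#       def _save_func(self, mod, fn, cur_fn):
#           self._saved.append((mod, fn, cur_fn))
#
#   orig_addmm = get_func(torch, 'addmm')
#   def addmm_fp16(*args, **kwargs):
#       return orig_addmm(*casted_args(maybe_half, args, kwargs), **kwargs)
#
#   handle = _Handle()
#   set_func_save(handle, torch, 'addmm', addmm_fp16)
#   # handle._saved now records the original so the patch can be undone later.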

# A couple problems get solved here:
# - The flat_weight buffer is disconnected from autograd graph,
#   so the fp16 weights need to be derived from the input weights
#   to this forward call, not the flat buffer.
# - The ordering of weights in the flat buffer is...idiosyncratic.
# The first problem is solved with a combination of set_ (to set up
# the correct storage) and copy_ (so the fp16 weight derives from the
# fp32 one in autograd).
# The second is solved by doing pointer arithmetic on the fp32 weights
# to derive the correct offset.
#
# TODO: maybe this should actually use
# `torch._cudnn_rnn_flatten_weight`? But then I need to call
# on first iter and cache the right offsets. Ugh.
def synthesize_flattened_rnn_weights(fp32_weights,
                                     fp16_flat_tensor,
                                     rnn_fn='',
                                     verbose=False):
    fp16_weights = []
    fp32_base_ptr = fp32_weights[0][0].data_ptr()
    for layer_weights in fp32_weights:
        fp16_layer_weights = []
        for w_fp32 in layer_weights:
            w_fp16 = w_fp32.new().half()
            offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
            w_fp16.set_(fp16_flat_tensor.storage(),
                        offset,
                        w_fp32.shape)
            w_fp16.copy_(w_fp32)
            if verbose:
                print('Float->Half ({})'.format(rnn_fn))
            fp16_layer_weights.append(w_fp16)
        fp16_weights.append(fp16_layer_weights)
    return fp16_weights
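
# A rough CPU-only sketch of the offset arithmetic (illustrative; real use
# targets the flattened cuDNN weight buffer on the GPU):
#
#   flat_fp32 = torch.randn(6)
#   fp32_weights = [[flat_fp32[0:4].view(2, 2), flat_fp32[4:6]]]
#   flat_fp16 = torch.zeros(6).half()
#   fp16 = synthesize_flattened_rnn_weights(fp32_weights, flat_fp16)
#   # fp16[0][1] is a 2-element half view starting at element offset 4 of
#   # flat_fp16, mirroring where flat_fp32[4:6] sits in the fp32 buffer, and
#   # copy_ makes it an autograd child of the corresponding fp32 weight.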

# Roughly the same as above, except the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
                                         fp16_flat_tensor,
                                         rnn_fn='',
                                         verbose=False):
    fp16_weights = []
    fp32_base_ptr = fp32_weights[0].data_ptr()
    for w_fp32 in fp32_weights:
        w_fp16 = w_fp32.new().half()
        offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
        w_fp16.set_(fp16_flat_tensor.storage(),
                    offset,
                    w_fp32.shape)
        w_fp16.copy_(w_fp32)
        if verbose:
            print('Float->Half ({})'.format(rnn_fn))
        fp16_weights.append(w_fp16)
    return fp16_weights