import time
from functools import partial
from typing import Any, Callable, Dict, Tuple

import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.nn.parameter import Parameter
from torch.utils._pytree import tree_map

from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor

__all__ = ["profile_function", "profile_module", "profile_method"]

# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
cache = set()

# a global flag for inplace ops
do_not_cache = False


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()


def detach_variables(x):
    if isinstance(x, torch.Tensor):
        requires_grad = x.requires_grad
        x = x.detach()
        x.requires_grad = requires_grad
    return x


@compatibility(is_backward_compatible=True)
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """Profile a Callable function with args and kwargs on concrete devices (by https://github.com/Cypher30).

    To profile the actual forward memory, we first run `target` under `torch.no_grad()` to get
    `fwd_mem_out`, then run `target` with gradients enabled and take the extra memory kept alive,
    i.e. the memory allocated minus `fwd_mem_out`.
    To profile the actual backward memory, we first build dummy gradients for `torch.autograd.backward`,
    then take `bwd_mem_tmp` as the peak memory during the backward pass minus `bwd_mem_out` (which is
    actually equal to the size of args and kwargs).
    We also add timestamps to profile the real forward and backward time.

    Args:
        target (Callable): A Callable function
        args (Any): Positional arguments
        kwargs (Any): Keyword arguments

    Returns:
        Tuple[Tuple[Any, ...], GraphInfo]: Output for the next node, together with the memory cost and the
        real forward and backward time.
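
    Example:
        A minimal usage sketch; it assumes a CUDA device is available, since the measurements rely on
        `torch.cuda.memory_allocated()`:

        >>> x = torch.rand(1024, 1024, device='cuda', requires_grad=True)
        >>> out, meta_info = _profile_concrete(torch.nn.functional.relu, x)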
    """

    graphinfo = GraphInfo()

    # detach input from the graph
    args = tree_map(detach_variables, args)
    kwargs = tree_map(detach_variables, kwargs)
    if isinstance(target, str):
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # calculate fwd_mem_out
        mem_stamp0 = torch.cuda.memory_allocated()
        with torch.no_grad():
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
        del out

        # calculate fwd_mem_tmp & fwd_time
        mem_stamp0 = torch.cuda.memory_allocated()
        fwd_time0 = time.time()
        out = getattr(self_obj, target)(*args_tail, **kwargs)
        fwd_time1 = time.time()
        graphinfo.fwd_time = fwd_time1 - fwd_time0
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

        # calculate bwd_mem_tmp & bwd_time
        grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        bwd_time0 = time.time()
        torch.autograd.backward(out, grad_tensors=grad_tensors)
        bwd_time1 = time.time()
        graphinfo.bwd_time = bwd_time1 - bwd_time0
        mem_stamp1 = torch.cuda.max_memory_allocated()

        # calculate bwd memory stats
        # NOTE: a module target should add its parameters to bwd_mem_out for the bwd_mem_tmp calculation
        graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
        graphinfo.bwd_mem_out += parameter_size(self_obj) if hasattr(self_obj, "parameters") else 0
        graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out

    else:
        # calculate fwd_mem_out
        mem_stamp0 = torch.cuda.memory_allocated()
        with torch.no_grad():
            out = target(*args, **kwargs)
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
        del out

        # calculate fwd_mem_tmp & fwd_time
        mem_stamp0 = torch.cuda.memory_allocated()
        fwd_time0 = time.time()
        out = target(*args, **kwargs)
        fwd_time1 = time.time()
        graphinfo.fwd_time = fwd_time1 - fwd_time0
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

        # calculate bwd_mem_tmp & bwd_time
        grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        bwd_time0 = time.time()
        torch.autograd.backward(out, grad_tensors=grad_tensors)
        bwd_time1 = time.time()
        graphinfo.bwd_time = bwd_time1 - bwd_time0
        mem_stamp1 = torch.cuda.max_memory_allocated()

        # calculate bwd memory stats
        # NOTE: a module target should add its parameters to bwd_mem_out for the bwd_mem_tmp calculation
        graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
        if hasattr(target, "__self__") and hasattr(target.__self__, "parameters"):
            graphinfo.bwd_mem_out += parameter_size(target.__self__)
        graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out

    return tree_map(detach_variables, out), graphinfo


@compatibility(is_backward_compatible=False)
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """
    Profile a Callable function with args and kwargs on meta devices.

    Args:
        target (Callable): A Callable function
        args (Any): Positional arguments
        kwargs (Any): Keyword arguments

    Returns:
        out (Tuple[Any, ...]): The output of the profiled callable.
        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
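
    Example:
        A minimal usage sketch mirroring the public wrappers below; tensors must live on the `meta` device:

        >>> x = torch.rand(8, 8, device='meta')
        >>> out, meta_info = _profile_meta(torch.nn.functional.relu, x)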
    """
    # This subgraph traces aten level ops inside one node.
    subgraph = Graph()

    # `flop_count` serves as a dictionary shared with `FlopTensor.__torch_dispatch__` to store per-phase FLOP results.
    flop_count = {
        Phase.FORWARD: 0,
        Phase.BACKWARD: 0,
    }

    # FlopTensor not only gathers the FLOP statistics of a single node,
    # it also builds a full autograd graph for this node.
    # This makes sure we can analyze the memory dependencies and
    # decide which forward intermediate results should be kept until
    # backward is executed.
    # Hopefully, this attempt will provide a better estimation of memory.
    class FlopTensor(MetaTensor):
        _node: Node = None

        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
            return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
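            # Record this aten-level call as a `call_function` node in the local subgraph,
            # linking it to the nodes that produced its tensor arguments.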
            args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
            kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
            node = subgraph.create_node("call_function", func, args_node, kwargs_node)

            out = super().__torch_dispatch__(func, types, args, kwargs)

            flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
            node.meta["phase"] = phase

            # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
            # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
            # `Phase.FORWARD`
            if phase == Phase.FORWARD:
                if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
                    node.meta["phase"] = Phase.PLACEHOLDER

            # TODO(yby): specify `saved_tensors` for backward memory estimation
            node.meta["saved_tensor"] = []
            if phase == Phase.BACKWARD:
                node.meta["saved_tensor"] = normalize_tuple(out)

            def wrap(x):
                if isinstance(x, MetaTensor):
                    x = FlopTensor(x)
                    x._node = node
                return x

            out = tree_map(wrap, out)
            return out

    def wrap(x):
        if isinstance(x, torch.Tensor):
            x = FlopTensor(x)
            if is_autogradable(x):
                x.requires_grad_(True)
            x._node = subgraph.create_node(
                "placeholder",
                "placeholder",
                (subgraph._root,),
                name=subgraph._graph_namespace.create_name("input", x._tensor),
            )
            x._node.meta["phase"] = Phase.PLACEHOLDER
            x._node.meta["saved_tensor"] = []
        return x

    # Basically, we need to detach the args and kwargs from the outer graph.
    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    def pack(x):
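        # Saved-tensors hook: every tensor that autograd saves for backward is recorded in the
        # creating node's `saved_tensor` meta, while the global `cache` of data pointers avoids
        # counting the same storage twice across nodes (`unpack` below returns tensors unchanged).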
        global cache, do_not_cache
        if isinstance(x, FlopTensor) and x._tensor.data_ptr() not in cache:
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            x._node.meta["saved_tensor"] += [tensor]
            if not do_not_cache:
                cache.add(x._tensor.data_ptr())
        return x

    def unpack(x):
        return x

    # `phase` will mark the phase of autograd from outside scope.
    phase = Phase.FORWARD
    # mark saved tensors with saved_tensors_hooks
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        if isinstance(target, str):
            # args[0] is the `self` object for this method call
            self_obj, *args_tail = args
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        else:
            out = target(*args, **kwargs)

        # If the output is not a floating point `torch.Tensor` or it does not
        # require grad, then we should not run backward for this node.
        if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
            grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
            phase = Phase.BACKWARD
            torch.autograd.backward(
                out,
                grad_out,
            )

    graph_info = autograd_graph_analysis(subgraph)
    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]

    def extract_tensor(x: Any):
        if isinstance(x, MetaTensor):
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            return tensor
        if not isinstance(x, torch.finfo):
            return x

    graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))

    def unwrap(x):
        return MetaTensor(x) if isinstance(x, torch.Tensor) else x

    return tree_map(unwrap, out), graph_info


@compatibility(is_backward_compatible=True)
def profile_function(target: "Target", device: str = "meta") -> Callable:
    """
    Wrap a `call_function` node or a `torch.nn.functional` call in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` functions are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, meta_info = profile_function(func)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # find the parameters in args and kwargs so that their gradient memory is not counted
        param_size = 0

        def get_param_size(x):
            nonlocal param_size
            if isinstance(x, Parameter):
                param_size += activation_size(x)

        tree_map(get_param_size, args)
        tree_map(get_param_size, kwargs)

        # If an argument indicates that this `call_function` is inplace, we should
        # still run the profiling but discard some results regarding `target`.
        global do_not_cache

        inplace = kwargs.get("inplace", False)
        if target in OUTPUT_SAVED_OPS:
            do_not_cache = True
        if inplace:
            do_not_cache = True
            kwargs["inplace"] = False
        if device == "meta":
            out, meta = _profile_meta(func, *args, **kwargs)
        else:
            out, meta = _profile_concrete(func, *args, **kwargs)
        if inplace:
            kwargs["inplace"] = True
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        do_not_cache = False

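        # grad for param will not be counted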
        meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = target.__name__
    func = target
    return f


@compatibility(is_backward_compatible=True)
def profile_method(target: "Target", device: str = "meta") -> Callable:
    """
    Wrap a `call_method` node to
    record the memory cost and FLOPs of the execution.
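
    Examples:
        A minimal sketch of the intended usage; the method name and shapes here are illustrative:

        >>> a = torch.rand(4, 4, device='meta')
        >>> b = torch.rand(4, 4, device='meta')
        >>> output, meta_info = profile_method('matmul')(a, b)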
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # execute the method and return the result
        assert isinstance(target, str), f"expected a str method name, but got {target}."
        if device == "meta":
            out, meta = _profile_meta(target, *args, **kwargs)
        else:
            out, meta = _profile_concrete(target, *args, **kwargs)
        return out, meta

    return f


@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
    """
    Wrap a `call_module` node or a `torch.nn` module in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` modules are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, meta_info = profile_module(mod)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # calculate parameter size
        param_size = parameter_size(module)

        # If this `call_module` is inplace (i.e. the module's `inplace` attribute is set), we should
        # still run the profiling but discard some results regarding `module`.
        global do_not_cache

        inplace = getattr(module, "inplace", False)
        if type(module) in OUTPUT_SAVED_MOD:
            do_not_cache = True
        if inplace:
            do_not_cache = True
            module.inplace = False
        if device == "meta":
            out, meta = _profile_meta(func, *args, **kwargs)
        else:
            out, meta = _profile_concrete(func, *args, **kwargs)
        if inplace:
            module.inplace = True
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        do_not_cache = False

        # grad for param will not be counted
        meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = module.__class__.__name__
    func = module.forward
    return f