from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx import Graph
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS
from .tensor import MetaTensor
from .opcount import flop_mapping

__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile']


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    return isinstance(x, torch.Tensor) and x.is_floating_point()


def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
    """Profile a Callable function with args and kwargs.

    Args:
        target (Callable): A Callable function
        args (Any): Positional arguments passed to `target`
        kwargs (Any): Keyword arguments passed to `target`

    Returns:
        out (Any): The output of `target`, returned as meta tensors
        flop_count (Tuple[int, ...]): The FLOP counts as (fwd_flop, bwd_flop)
        mem_stat (Tuple[int, ...]): The memory statistics as (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
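
    Examples:
        An illustrative sketch; it assumes the underlying aten op of
        `torch.mm` is registered in `flop_mapping`:

        >>> x = torch.rand(4, 4, device='meta')
        >>> out, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = _profile(torch.mm, x, x)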
    """

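    # bookkeeping for the three profiling stages:
    #   'f' = forward, 'l' = loss computation, 'b' = backward
    # `flop_count` accumulates FLOPs per stage; `temp` records the (meta)
    # outputs produced in each stage so their activation sizes can be measured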
    flop_count = {
        'f': 0,
        'l': 0,
        'b': 0,
    }
    temp = {
        'f': [],
        'l': [],
        'b': [],
    }
    stage = 'f'

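    # FlopTensor intercepts every aten call via `__torch_dispatch__`, looks up
    # the op's FLOP cost in `flop_mapping`, and charges it to the current
    # `stage` while the op itself executes on the meta backend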
    class FlopTensor(MetaTensor):

        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
            return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

            def unwrap(x):
                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
                    x = FlopTensor(x.to('meta'))
                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

            def to_meta(x):
                return x.to('meta') if isinstance(x, torch.Tensor) else x

            args = tree_map(unwrap, args)
            kwargs = tree_map(unwrap, kwargs)

            # run the aten op on the meta backend: shapes and dtypes
            # propagate, but no real computation happens
            out = func(*args, **kwargs)
            flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
            if func not in INPLACE_ATEN:
                temp[stage].append(tree_map(to_meta, normalize_tuple(out)))

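            # re-wrap outputs as FlopTensor so downstream ops keep
            # dispatching through this class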
            def wrap(x):
                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x

            return tree_map(wrap, out)

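    # ops listed in WEIRD_OPS cannot tolerate inputs that require grad, so
    # their arguments are wrapped with autograd disabled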
    if target not in WEIRD_OPS:

        def wrap(x):
            return FlopTensor(
                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
    else:

        def wrap(x):
            return FlopTensor(
                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    if isinstance(target, str):
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args
        out = getattr(self_obj, target)(*args_tail, **kwargs)
    else:
        out = target(*args, **kwargs)

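    # if the output supports autograd, run a backward pass from a dummy
    # sum() loss so backward FLOPs and activations are recorded under the
    # 'l' and 'b' stages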
    if is_autogradable(out) and out.requires_grad:
        stage = 'l'
        loss = out.sum()
        stage = 'b'
        loss.backward()

    fwd_flop = flop_count['f']
    bwd_flop = flop_count['b']

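    # fwd_out is the size of the final forward output; fwd_tmp estimates peak
    # temporary memory as the largest intermediate activation (the last entry
    # of temp['f'] is the output itself, so it is excluded)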
    fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0
    fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0
    bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0

    def unwrap(x):
        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

    return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0)


def profile_function(target: 'Target') -> Callable:
    """
    Wrap a `call_function` node or a `torch.nn.functional` function to
    record the memory cost and FLOPs of its execution.

    Warnings:
        You may only use tensors with `device=meta` with this wrapped function.
        Only the original `torch.nn.functional` functions are supported.
    
    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
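        # an inplace call mutates its input instead of allocating new
        # activations, so run it directly on meta tensors and report zero
        # FLOPs and zero memory cost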
        if kwargs.get('inplace', False):
            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
            out = func(*args, **kwargs)
            return out, (0, 0), (0, 0, 0, 0)
        out, flop_count, mem_stat = _profile(func, *args, **kwargs)
        return out, flop_count, mem_stat

    f.__name__ = target.__name__
    func = target
    return f


def profile_method(target: 'Target') -> Callable:
    """
    Wrap a `call_method` node and
    record the memory cost and FLOPs of its execution.
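
    Examples:
        An illustrative sketch; it assumes the aten op underlying the
        method is registered in `flop_mapping`:

        >>> x = torch.rand(4, 4, device='meta')
        >>> out, flop_count, mem_stat = profile_method('matmul')(x, x)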
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # execute the method and return the result
        assert isinstance(target, str), f'The target of a call_method node must be a str, but got {type(target)}.'
        out, flop_count, mem_stat = _profile(target, *args, **kwargs)
        return out, flop_count, mem_stat

    return f


def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or a `torch.nn` module to
    record the memory cost and FLOPs of its execution.

    Warnings:
        You may only use tensors with `device=meta` with this wrapped module.
        Only the original `torch.nn` modules are supported.
    
    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
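        # modules flagged `inplace=True` (e.g. `nn.ReLU(inplace=True)`) mutate
        # their input; approximate their cost as one FLOP per output element
        # in both forward and backward, and record no activation memory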
        if getattr(module, 'inplace', False):
            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
            out = func(*args, **kwargs)
            return out, (out.numel(), out.numel()), (0, 0, 0, 0)
        out, flop_count, mem_stat = _profile(func, *args, **kwargs)
        return out, flop_count, mem_stat

    f.__name__ = module.__class__.__name__
    func = module.forward
    return f