# Copyright (c) 2023, Tri Dao.
""" Useful functions for writing test code. """

import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
                      amp_dtype=torch.float16, **kwinputs):
    """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """
    if verbose:
        print(desc, '- Forward pass')

    def timed_call(*call_args, **call_kwargs):
        # Only wall time is of interest, so the forward output is discarded.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            fn(*call_args, **call_kwargs)

    timer = benchmark.Timer(
        stmt='fn_amp(*inputs, **kwinputs)',
        globals={'fn_amp': timed_call, 'inputs': inputs, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement


def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """ Use Pytorch Benchmark on the backward pass of an arbitrary function. """
    if verbose:
        print(desc, '- Backward pass')
    # Build the autograd graph once, untimed; only backward is benchmarked.
    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError('Grad shape does not match output shape')

    def run_backward(*tensors, y, grad):
        # Clear .grad first so gradient accumulation is excluded from timing.
        for t_ in tensors:
            if isinstance(t_, torch.Tensor):
                t_.grad = None
        y.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt='f(*inputs, y=y, grad=grad)',
        globals={'f': run_backward, 'inputs': inputs, 'y': y, 'grad': grad},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement


def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
    if verbose:
        print(desc, '- Forward + Backward pass')
    # One untimed forward run to obtain an output for shaping/validating grad.
    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError('Grad shape does not match output shape')

    def fwd_bwd(grad, *call_args, **call_kwargs):
        # Clear .grad so accumulation does not pollute the timing.
        for x in call_args:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            out = fn(*call_args, **call_kwargs)
            if type(out) is tuple:
                out = out[0]
        out.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt='f(grad, *inputs, **kwinputs)',
        globals={'f': fwd_bwd, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement


def benchmark_fwd_bwd(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                      amp_dtype=torch.float16, **kwinputs):
    """Benchmark the forward pass and the backward pass separately; return both results."""
    fwd = benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
                            amp=amp, amp_dtype=amp_dtype, **kwinputs)
    bwd = benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc,
                             verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs)
    return (fwd, bwd)


def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                  amp_dtype=torch.float16, **kwinputs):
    """Benchmark forward, backward, and combined forward+backward passes of ``fn``."""
    common = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    return (
        benchmark_forward(fn, *inputs, **common, **kwinputs),
        benchmark_backward(fn, *inputs, grad=grad, **common, **kwinputs),
        benchmark_combined(fn, *inputs, grad=grad, **common, **kwinputs),
    )


def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
                     amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
    """ Wrap benchmark functions in Pytorch profiler to see CUDA information.

    Runs ``fn`` 30 times untimed as warm-up, then profiles a single call
    (plus its backward pass when ``backward=True``).  If ``trace_filename``
    is given, a Chrome trace is exported for inspection.
    """
    if backward:
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            # Fix: match benchmark_backward/benchmark_combined — when fn
            # returns a tuple, the first element is the differentiable output.
            # Previously randn_like(fn(...)) crashed on tuple outputs.
            if type(y) is tuple:
                y = y[0]
            g = torch.randn_like(y)
    for _ in range(30):   # Warm up
        if backward:
            # Clear .grad so gradient accumulation doesn't show up in the profile.
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
    """Measure peak CUDA memory allocated during a single call of ``fn``.

    Clears the caching allocator and resets peak stats first, runs ``fn``
    once, then reads the peak allocation.  Returns the peak in units of
    1000 * 2**20 bytes (printed as "GB", though strictly this is 1000 MiB,
    not 2**30 bytes).
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    # NOTE(review): divisor is 1000 * 2**20 = 1000 MiB per "GB", not a true
    # GiB (2**30) — kept as-is so reported numbers stay comparable.
    mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)
    if verbose:
        print(f'{desc} max memory: {mem}GB')
    torch.cuda.empty_cache()
    return mem