"model_cards/indobenchmark/indobert-large-p2/README.md" did not exist on "1313a1d2a833ecbc7f37d2a855717e4b693e7538"
benchmark.py 5.77 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
# Copyright (c) 2022, Tri Dao.
""" Useful functions for benchmarking and profiling PyTorch code. """

import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
                      amp_dtype=torch.float16, **kwinputs):
    """ Use PyTorch Benchmark on the forward pass of an arbitrary function. """
    if verbose:
        print(desc, '- Forward pass')
    def fn_amp(*inputs, **kwinputs):
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)
    for _ in range(repeats):  # warmup
        fn_amp(*inputs, **kwinputs)
    t = benchmark.Timer(
            stmt='fn_amp(*inputs, **kwinputs)',
            globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs},
            num_threads=torch.get_num_threads(),
            )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
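
# Example usage (illustrative; the tensor and module below are hypothetical and
# assume a CUDA GPU is available):
#   x = torch.randn(32, 1024, device='cuda', requires_grad=True)
#   linear = torch.nn.Linear(1024, 1024, device='cuda')
#   t, m = benchmark_forward(linear, x, repeats=30, desc='linear', amp=True)
#   print(m.mean)  # mean time per call, in seconds
# benchmark_backward and benchmark_combined below take the same arguments, plus an
# optional `grad` tensor that must match the output shape.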


def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """ Use PyTorch Benchmark on the backward pass of an arbitrary function. """
    if verbose:
        print(desc, '- Backward pass')
    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError('Grad shape does not match output shape')
    for _ in range(repeats):  # warmup
        y.backward(grad, retain_graph=True)
    t = benchmark.Timer(
            stmt='y.backward(grad, retain_graph=True)',
            globals={'y': y, 'grad': grad},
            num_threads=torch.get_num_threads(),
        )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """ Use PyTorch Benchmark on the forward+backward pass of an arbitrary function. """
    if verbose:
        print(desc, '- Forward + Backward pass')
    def f(grad, *inputs, **kwinputs):
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        if grad is None:
            grad = torch.randn_like(y)
        else:
            if grad.shape != y.shape:
                raise RuntimeError('Grad shape does not match output shape')
        y.backward(grad, retain_graph=True)
    for _ in range(repeats):  # warmup
        f(grad, *inputs, **kwinputs)
    t = benchmark.Timer(
            stmt='f(grad, *inputs, **kwinputs)',
            globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
            num_threads=torch.get_num_threads(),
            )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                  amp_dtype=torch.float16, **kwinputs):
    """ Run the forward, backward, and combined benchmarks on an arbitrary function. """
    return (
        benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
                          amp=amp, amp_dtype=amp_dtype, **kwinputs),
        benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
        benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
    )
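
# benchmark_all returns the three (Timer, Measurement) pairs in order, so a call
# could be unpacked as follows (illustrative, reusing the hypothetical linear/x above):
#   (_, fwd), (_, bwd), (_, fwd_bwd) = benchmark_all(linear, x, repeats=30, desc='linear')
#   print(fwd.mean, bwd.mean, fwd_bwd.mean)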


def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
                     amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
    """ Run fn under the PyTorch profiler to see CUDA kernel information. """
    if backward:
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            g = torch.randn_like(fn(*inputs, **kwinputs))
    for _ in range(30):   # Warm up
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            if backward:
                for x in inputs:
                    if isinstance(x, torch.Tensor):
                        x.grad = None
            # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g)
            out = fn(*inputs, **kwinputs)
        # Backward should be done outside autocast
        if backward:
            out.backward(g)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            if backward:
                for x in inputs:
                    if isinstance(x, torch.Tensor):
                        x.grad = None
            out = fn(*inputs, **kwinputs)
        if backward:
            out.backward(g)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
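
# Example (illustrative, reusing the hypothetical linear/x above): profile forward +
# backward and export a Chrome trace, which can be opened in chrome://tracing or Perfetto:
#   pytorch_profiler(linear, x, trace_filename='linear_fwd_bwd.json', backward=True, amp=True)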


def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
    """ Measure peak CUDA memory allocated (in GB) during a single call to fn. """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)  # peak bytes -> GB (approx.)
    if verbose:
        print(f'{desc} max memory: {mem}GB')
    torch.cuda.empty_cache()
    return mem
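

# Minimal smoke test (illustrative sketch, not exhaustive): benchmark a plain matmul
# with the helpers above. Assumes a CUDA-capable GPU, since the autocast contexts
# target 'cuda' and benchmark_memory reads CUDA memory stats.
if __name__ == '__main__':
    a = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    b = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    benchmark_all(torch.matmul, a, b, repeats=30, desc='matmul', amp=True)
    benchmark_memory(torch.matmul, a, b, desc='matmul')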