# perf.py - standalone Evoformer performance benchmark (FastFold / OpenFold).
import argparse
import os

import torch
import torch.nn as nn

Shenggan's avatar
Shenggan committed
7
from fastfold.distributed import init_dap
8
from fastfold.model.fastnn import Evoformer
Shenggan's avatar
Shenggan committed
9
10
11
12


def _parse_args():
    """Parse and validate the benchmark's command-line options.

    Returns:
        argparse.Namespace holding the options declared below.

    Exits via ``parser.error`` when a count option is out of range; e.g.
    ``--trials 0`` or ``--layers 0`` would otherwise only fail with a
    ZeroDivisionError after the entire benchmark had run.
    """
    parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
    parser.add_argument('--dap-size',
                        default=1,
                        type=int,
                        help='Dynamic Axial Parallelism (DAP) size passed to init_dap')
    parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
    parser.add_argument('--res-length', default=256, type=int, help='Sequence Length of Residues')
    parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
    parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
    parser.add_argument('--layers', default=12, type=int, help='Evoformer Layers to Execute')
    parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
    parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
    # NOTE(review): --heads is parsed but never forwarded to either Evoformer
    # implementation; the OpenFold head counts are hard-coded in _build_layers.
    parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
    parser.add_argument('--openfold',
                        action='store_true',
                        help='Benchmark with Evoformer Implementation from OpenFold.')
    parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
    parser.add_argument('--prof', action='store_true', help='run with profiler.')

    args = parser.parse_args()
    if args.trials < 1:
        parser.error('--trials must be at least 1')
    if args.warmup_trials < 0:
        parser.error('--warmup-trials must be non-negative')
    if args.layers < 1:
        parser.error('--layers must be at least 1')
    return args


def _build_layers(args, precision):
    """Create ``args.layers`` Evoformer blocks, moved to the GPU and cast to ``precision``.

    With ``--openfold`` each layer wraps OpenFold's ``EvoformerBlock``;
    otherwise FastFold's ``Evoformer`` is used, with ``first_block`` set on
    the first layer and ``last_block`` set on the last one.

    Returns:
        list of nn.Module, in execution order.
    """
    if args.openfold:
        # Imported lazily so OpenFold is only required when it is benchmarked.
        from openfold.model.evoformer import EvoformerBlock

        class OpenFoldEvoformer(nn.Module):
            """Wrap OpenFold's EvoformerBlock behind a ``(node, pair, node_mask, pair_mask)`` call."""

            def __init__(self, d_node, d_pair):
                super().__init__()
                self.d_node = d_node
                self.d_pair = d_pair

                # Attention hidden widths are 1/8 of the channel widths.
                self.c_hidden_msa_att = d_node // 8
                self.c_hidden_pair_att = d_pair // 8

                self.EvoformerBlock = EvoformerBlock(c_m=d_node,
                                                     c_z=d_pair,
                                                     c_hidden_msa_att=self.c_hidden_msa_att,
                                                     c_hidden_opm=self.c_hidden_msa_att,
                                                     c_hidden_mul=self.d_pair,
                                                     c_hidden_pair_att=self.c_hidden_pair_att,
                                                     no_heads_msa=8,
                                                     no_heads_pair=4,
                                                     transition_n=4,
                                                     msa_dropout=0.15,
                                                     pair_dropout=0.25,
                                                     inf=1e9,
                                                     eps=1e-10)

            def forward(self, node, pair, node_mask, pair_mask):
                node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask)
                return node, pair

    layers = []
    for idx in range(args.layers):
        if args.openfold:
            layer = OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz)
        else:
            layer = Evoformer(c_m=args.cm,
                              c_z=args.cz,
                              first_block=(idx == 0),
                              last_block=(idx == args.layers - 1))
        layer.cuda()
        layer.to(dtype=precision)
        layers.append(layer)
    return layers


def main():
    """Benchmark per-layer Evoformer forward/backward time and print the averages.

    Runs ``--warmup-trials`` untimed trials followed by ``--trials`` timed
    ones. Each trial pushes random MSA ("node") and pair tensors through
    ``--layers`` blocks: FastFold by default, or OpenFold with ``--openfold``.
    Unless ``--fwd`` is given, each trial then back-propagates a random
    gradient from the pair output. Times are measured with CUDA events, and
    the average per layer is printed.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    args = _parse_args()

    # Check for a GPU before init_dap, so CPU-only hosts get this explicit
    # error instead of whatever init_dap does without CUDA.
    if not torch.cuda.is_available():
        raise NotImplementedError('Running on CPU is not supported')

    init_dap(args.dap_size)

    precision = torch.bfloat16
    if args.dap_size > 1:
        # (PyTorch issue) Currently All2All communication does not support the Bfloat16 datatype in PyTorch
        precision = torch.float16

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    attn_layers = _build_layers(args, precision)

    # One (fwd start, bwd start, bwd stop) event triple per timed trial.
    start_evt_fwd = [torch.cuda.Event(enable_timing=True) for _ in range(args.trials)]
    start_evt_bwd = [torch.cuda.Event(enable_timing=True) for _ in range(args.trials)]
    stop_evt_bwd = [torch.cuda.Event(enable_timing=True) for _ in range(args.trials)]

    device = torch.device("cuda")
    batch_size = 1
    inputs_node = torch.randn(batch_size,
                              args.msa_length,
                              args.res_length,
                              args.cm,
                              dtype=precision,
                              device=device).requires_grad_(True)
    inputs_pair = torch.randn(batch_size,
                              args.res_length,
                              args.res_length,
                              args.cz,
                              dtype=precision,
                              device=device).requires_grad_(True)
    node_mask = torch.ones((batch_size, args.msa_length, args.res_length),
                           dtype=precision,
                           device=device)
    pair_mask = torch.ones((batch_size, args.res_length, args.res_length),
                           dtype=precision,
                           device=device)
    # Upstream gradient fed to the pair output (layer_inputs[1]) in backward.
    grads_pair = torch.randn_like(inputs_pair)

    prof = None
    if args.prof:
        # wait=0 makes the profiler's warmup/active windows coincide with the
        # benchmark's warmup/timed trials. prof.step() is called exactly
        # warmup_trials + trials times, so a wait step would shift the active
        # window by one and leave one timed trial unrecorded.
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=0,
                                             warmup=args.warmup_trials,
                                             active=args.trials,
                                             repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/fastfold'),
            profile_memory=False,
            record_shapes=False,
            with_stack=False)
        prof.start()

    if not args.openfold:
        # FastFold's Evoformer is driven without the leading batch dimension
        # (batch_size is 1, so squeeze(0) simply drops it).
        inputs_node = inputs_node.squeeze(0)
        inputs_pair = inputs_pair.squeeze(0)
        node_mask = node_mask.squeeze(0)
        pair_mask = pair_mask.squeeze(0)
        grads_pair = grads_pair.squeeze(0)

    for trial in range(args.warmup_trials + args.trials):
        layer_inputs = inputs_node, inputs_pair
        evt_idx = trial - args.warmup_trials
        timed = evt_idx >= 0

        # NOTE(review): assumes init_dap sets up the default process group
        # even for --dap-size 1; confirm, since barrier() requires it.
        torch.distributed.barrier()
        torch.cuda.synchronize()

        if timed:
            start_evt_fwd[evt_idx].record()

        # Inputs require grad, so the autograd graph is built even with --fwd.
        for layer in attn_layers:
            layer_inputs = layer(*layer_inputs, node_mask, pair_mask)

        torch.cuda.synchronize()

        if timed:
            start_evt_bwd[evt_idx].record()

        if not args.fwd:
            layer_inputs[1].backward(grads_pair)

        if timed:
            stop_evt_bwd[evt_idx].record()

        if prof is not None:
            prof.step()

    if prof is not None:
        prof.stop()

    # Ensure every recorded event has completed before reading elapsed times.
    torch.distributed.barrier()
    torch.cuda.synchronize()
    elapsed_time_fwd = sum(
        start.elapsed_time(stop) for start, stop in zip(start_evt_fwd, start_evt_bwd))
    elapsed_time_bwd = sum(
        start.elapsed_time(stop) for start, stop in zip(start_evt_bwd, stop_evt_bwd))

    num_layer_runs = args.trials * args.layers
    print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
        batch_size, args.msa_length, args.res_length,
        args.cm, args.cz,
        elapsed_time_fwd / num_layer_runs,
        elapsed_time_bwd / num_layer_runs))


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, never on import.
    main()