# test_deepspeed_evo_attention.py
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests comparing OpenFold components run with the DeepSpeed memory-efficient
attention kernel (DS4Sci_EvoformerAttention) against a stock PyTorch attention
implementation.
"""

import pickle
import unittest

import numpy as np
import torch
from torch.nn import functional as F

from openfold.data import data_transforms
from openfold.model.primitives import (
    lecun_normal_init_,
    Attention,
)
from openfold.utils.tensor_utils import tensor_tree_map

from tests.config import consts
import tests.compare_utils as compare_utils
from tests.data_utils import random_template_feats, random_attention_inputs


@compare_utils.skip_unless_ds4s_installed()
class TestDeepSpeedKernel(unittest.TestCase):
    """Compare OpenFold components run with and without the DeepSpeed
    DS4Sci_EvoformerAttention kernel.

    All tensors and modules are placed on CUDA. The whole class is gated by
    ``compare_utils.skip_unless_ds4s_installed``.
    """

    @staticmethod
    def _run_with_and_without_ds_kernel(model, run_fn):
        """Run ``run_fn`` with the DeepSpeed Evoformer kernel disabled, then enabled.

        The kernel is toggled through ``model.globals.use_deepspeed_evo_attention``.
        The previous value of that flag is restored afterwards, even if
        ``run_fn`` raises. The model from
        ``compare_utils.get_global_pretrained_openfold`` is presumably a
        cached instance shared with other tests. Leaving the flag set to True
        would therefore silently change their behavior.

        Args:
            model: Model exposing a ``globals`` config object.
            run_fn: Zero-argument callable that performs one forward pass.

        Returns:
            Tuple ``(out_without_kernel, out_with_kernel)``.
        """
        prev_flag = getattr(model.globals, "use_deepspeed_evo_attention", False)
        try:
            model.globals.use_deepspeed_evo_attention = False
            out_ref = run_fn()
            model.globals.use_deepspeed_evo_attention = True
            out_ds = run_fn()
        finally:
            model.globals.use_deepspeed_evo_attention = prev_flag
        return out_ref, out_ds

    def compare_attention_types(self, use_flash=False):
        """Compare attention with and without using the DeepSpeed Evoformer kernel.

        Args:
            use_flash: If True, the reference output comes from the Flash
                Attention path. The mask is passed as ``flash_mask`` and only
                the first bias is kept for the DeepSpeed run. Otherwise the
                reference is the stock attention path with all biases.
        """
        batch_size = consts.batch_size
        n_seq = 18
        n_res = 20
        c_hidden = 32
        no_heads = 4
        eps = 2e-2

        q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
                                                      n_seq=n_seq,
                                                      n=n_res,
                                                      no_heads=no_heads,
                                                      c_hidden=c_hidden)

        a = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        # Change output params init for testing since they are initialized
        # with 'final' init (zeros). Otherwise both will just return zero.
        with torch.no_grad():
            lecun_normal_init_(a.linear_g.weight)
            lecun_normal_init_(a.linear_o.weight)

            if use_flash:
                # The Flash path takes the mask separately. Keep only the first
                # bias (presumably the mask-derived one) so the DeepSpeed run
                # sees equivalent inputs.
                biases = [biases[0]]
                flash_mask = mask.reshape(batch_size * n_seq, n_res)
                real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
            else:
                real_out = a(q, kv, biases=biases).cpu()

            ds_out = a(q, kv, biases=biases, use_deepspeed_evo_attention=True).cpu()

        err = torch.max(torch.abs(ds_out - real_out))
        self.assertTrue(err < eps, f'Error: {err}')

    def test_ds_kernel_vs_attention_forward(self):
        """Compare regular attention vs. DeepSpeed Evoformer kernel."""
        self.compare_attention_types(use_flash=False)

    @compare_utils.skip_unless_flash_attn_installed()
    def test_ds_kernel_vs_flash_attn_forward(self):
        """Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
        self.compare_attention_types(use_flash=True)

    def test_ds_kernel_vs_attention_backward(self):
        """Compare backward pass for regular attention vs. DeepSpeed Evoformer kernel."""
        batch_size = consts.batch_size
        n_seq = 18
        n_res = 20
        c_hidden = 32
        no_heads = 4
        eps = consts.eps

        q, kv, _, biases = random_attention_inputs(batch_size=batch_size,
                                                   n_seq=n_seq,
                                                   n=n_res,
                                                   no_heads=no_heads,
                                                   c_hidden=c_hidden,
                                                   requires_grad=True)

        attn = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        with torch.no_grad():
            lecun_normal_init_(attn.linear_g.weight)
            lecun_normal_init_(attn.linear_o.weight)

        def clone(t):
            # Copy the input. When t requires grad the copy is a non-leaf
            # tensor, so retain_grad() is needed for .grad to be populated.
            t = t.clone()
            if t.requires_grad:
                t.retain_grad()
            return t

        def init_attn():
            # Create a new Attention module with the same initial weights.
            a_clone = Attention(
                c_hidden, c_hidden, c_hidden, c_hidden, no_heads
            ).cuda()
            a_clone.load_state_dict(attn.state_dict())
            return a_clone

        # Clone inputs and run attention with the DS kernel
        q_repro = clone(q)
        kv_repro = clone(kv)
        biases_repro = [clone(b) for b in biases]

        a_repro = init_attn()
        out_repro = a_repro(q_repro, kv_repro, biases=biases_repro,
                            use_deepspeed_evo_attention=True)
        loss_repro = torch.mean(out_repro)
        loss_repro.backward()

        # Clone inputs and run attention without the DS kernel
        q_gt = clone(q)
        kv_gt = clone(kv)
        biases_gt = [clone(b) for b in biases]

        a_gt = init_attn()
        out_gt = a_gt(q_gt, kv_gt, biases=biases_gt)
        loss_gt = torch.mean(out_gt)
        loss_gt.backward()

        # Compare the grads of attention inputs (q, kv and the second bias)
        pairs = zip([q_repro, kv_repro, biases_repro[1]],
                    [q_gt, kv_gt, biases_gt[1]])
        for i, (t_repro, t_gt) in enumerate(pairs):
            err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
            self.assertTrue(err < eps, f'Error item #{i}: {err}')

        # Compare the grads of model weights, matched by parameter name
        a_repro_params = dict(a_repro.named_parameters())
        for name, t_gt in a_gt.named_parameters():
            t_repro = a_repro_params[name]
            err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
            self.assertTrue(err < eps, f'Error item {name}: {err}')

    def compare_evoformer(self, dtype, eps):
        """
        Compare Evoformer output with and without using the DeepSpeed Evoformer
        attention kernel.

        The ``dtype`` argument confirms the kernel can be used during both
        training (BF16) and inference (FP32). The kernel itself runs with
        either BF16 or FP16 precision.

        Args:
            dtype: dtype of the random activations/masks and of the autocast
                context.
            eps: Tolerance on the mean absolute difference of the
                layer-normalized MSA and pair outputs.
        """
        n_res = 20
        n_seq = 18
        c_m_shape = (consts.c_m,)
        c_z_shape = (consts.c_z,)

        activations = {
            "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
            "pair": torch.rand(n_res, n_res, consts.c_z, device='cuda', dtype=dtype)
        }

        masks = {
            "msa": torch.randint(0, 2, (n_seq, n_res), device='cuda', dtype=dtype),
            "pair": torch.randint(0, 2, (n_res, n_res), device='cuda', dtype=dtype),
        }

        # Forward-only comparison: no_grad avoids building autograd graphs for
        # the two evoformer block passes.
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
            model = compare_utils.get_global_pretrained_openfold()
            block = model.evoformer.blocks[0]

            def run_block(use_ds):
                msa, pair = block(
                    activations["msa"],
                    activations["pair"],
                    masks["msa"],
                    masks["pair"],
                    use_deepspeed_evo_attention=use_ds,
                    chunk_size=4,
                    _mask_trans=False,
                    inplace_safe=False,
                )
                # In practice, layer norms applied later in the network make
                # any kernel rounding errors negligible
                return (F.layer_norm(msa, c_m_shape).cpu(),
                        F.layer_norm(pair, c_z_shape).cpu())

            out_repro_msa, out_repro_pair = run_block(False)
            out_repro_msa_ds, out_repro_pair_ds = run_block(True)

        err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
        self.assertTrue(err < eps, f'MSA Error: {err}')

        err = torch.mean(torch.abs(out_repro_pair - out_repro_pair_ds))
        self.assertTrue(err < eps, f'Pair Error: {err}')

    def test_compare_evoformer_bf16(self):
        """Run evoformer comparison test with BF16 precision."""
        self.compare_evoformer(dtype=torch.bfloat16, eps=4e-2)

    def test_compare_evoformer_fp32(self):
        """Run evoformer comparison test with FP32 precision."""
        self.compare_evoformer(dtype=torch.float32, eps=2e-2)

    def test_compare_template_stack(self):
        """
        Compare Template Stack output with and without using the DeepSpeed
        Evoformer attention kernel. The kernel can be used for Triangle
        Attention in the Template Pair Stack.
        """
        n_templ = consts.n_templ
        n_res = 20
        eps = 2e-2

        batch = random_template_feats(n_templ, n_res)
        batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
        if consts.is_multimer:
            batch["asym_id"] = batch['asym_id'][0]

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        template_feats = {
            k: v for k, v in batch.items() if k.startswith("template_")
        }

        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()

            def embed():
                # Fresh pair tensors per call, as each run did originally.
                out = model.embed_templates(
                    template_feats,
                    batch,
                    torch.as_tensor(pair_act).cuda(),
                    torch.as_tensor(pair_mask).cuda(),
                    templ_dim=0,
                    inplace_safe=False,
                )
                return out["template_pair_embedding"].cpu()

            out_repro, out_repro_ds = self._run_with_and_without_ds_kernel(
                model, embed
            )

        compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)

    def test_compare_model(self):
        """
        Run the full model with and without using the DeepSpeed Evoformer
        attention kernel and compare the output coordinates.
        """
        eps = 0.2
        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
            batch = pickle.load(fp)

        # atom37_to_atom14 doesn't like batches
        batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
        batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]

        batch["no_recycling_iters"] = np.array([3., 3., 3., 3.])

        if consts.is_multimer:
            # The leading dim of 4 matches the four recycling entries above;
            # it is moved to the end by move_dim below.
            n_res = batch['aatype'].shape[1]
            n_extra_seq = batch['extra_msa'].shape[1]
            batch["asym_id"] = np.ones((4, n_res))
            batch["entity_id"] = np.ones((4, n_res))
            batch["sym_id"] = np.ones((4, n_res))
            batch["extra_deletion_matrix"] = np.random.randint(
                0, 2, size=(4, n_extra_seq, n_res)
            )

        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}

        batch["aatype"] = batch["aatype"].long()
        batch["template_aatype"] = batch["template_aatype"].long()
        batch["extra_msa"] = batch["extra_msa"].long()
        batch["residx_atom37_to_atom14"] = batch[
            "residx_atom37_to_atom14"
        ].long()
        batch["target_feat"] = F.one_hot(
            batch["aatype"], consts.msa_logits - 1
        ).to(torch.float32)
        batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
        batch.update(
            data_transforms.atom37_to_torsion_angles("template_")(batch)
        )

        def move_dim(t):
            # Move the recycling dimension (dim 0) to the end
            return t.permute(*range(len(t.shape))[1:], 0)

        batch = tensor_tree_map(move_dim, batch)

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32):
            model = compare_utils.get_global_pretrained_openfold()

            def run_model():
                # Move outputs to CPU right away so only one set of full-model
                # outputs is resident on the GPU at a time.
                return tensor_tree_map(lambda t: t.cpu(), model(batch))

            out_repro, out_repro_ds = self._run_with_and_without_ds_kernel(
                model, run_model
            )

        out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
        out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)

        compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps)


if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()