# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
    fused_attn_qkvpacked,
    fused_attn_kvpacked,
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
)

# Floating-point dtypes that every test in this module is parametrized over.
DTYPES = [jnp.float16, jnp.bfloat16]


class TestDistributedSelfAttn:
    """Compare sharded ``fused_attn_qkvpacked`` against a Flax reference attention."""

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        """Build the expected collective byte counts handed to ``compare_ops``.

        Args:
            mesh_shape: Size of each mesh axis, indexed in step with ``mesh_axes``.
            mesh_axes: Mesh axis names.
            mesh_resource: Object exposing ``dp_resource`` and ``tp_resource``.
            with_bias: Whether a bias tensor (and therefore its gradient) is used.
            shape: Packed QKV shape, unpacked here as ``(_, seqlen, _, heads, _)``.
            dtype: Element dtype used to size the bias gradient.

        Returns:
            The result of ``generate_collectives_count`` with only the
            all-reduce byte count non-zero.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, _, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None

        # Size of the tensor-parallel axis, or 1 when no such axis is configured.
        tp_size = 1
        if mesh_resource.tp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        # Bias gradient bytes for the heads held per tensor-parallel shard;
        # zero when no bias is used.
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        # The bias term is only counted when a data-parallel axis is configured.
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        # for loss and dbias
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype):
        """Create random inputs and their partition specs.

        Args:
            shape: Packed QKV shape, unpacked here as ``(batch, seqlen, _, heads, _)``.
            mesh_resource: Object exposing ``dp_resource`` and ``tp_resource``.
            with_bias: Whether to generate a ``(1, heads, seqlen, seqlen)`` bias.
            attn_mask_type: ``AttnMaskType`` selecting which mask helper is used.
            dtype: Element dtype of the generated tensors.

        Returns:
            ``((qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec))``; ``bias``
            and ``mask`` (and their specs) may be ``None``.
        """
        batch, seqlen, _, heads, _ = shape

        qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)

        bias = (
            random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
            if with_bias
            else None
        )

        # NOTE(review): PADDING_MASK is paired with make_causal_mask and
        # CAUSAL_MASK with make_self_mask. Judging by the helper names this
        # looks swapped, but the helpers are not visible here -- confirm
        # before changing.
        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            mask = make_causal_mask(batch, seqlen)
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        # Batch axis follows the data-parallel resource, head axis the
        # tensor-parallel resource.
        qkv_pspec = PartitionSpec(
            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
        )
        bias_pspec = (
            PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
        )
        mask_pspec = (
            PartitionSpec(mesh_resource.dp_resource, None, None, None)
            if attn_mask_type != AttnMaskType.NO_MASK
            else None
        )

        return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
    @pytest.mark.parametrize(
        "attn_bias_type",
        [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
    )
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        attn_mask_type,
        dtype,
    ):
        """Check fused packed-QKV self-attention against the Flax reference.

        The test is skipped when ``is_fused_attn_kernel_available`` reports no
        kernel for the requested configuration.
        """
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, _, num_head, hidden = data_shape

        # dtype, num_head and seqlen are each passed twice; presumably the
        # q-side and kv-side values -- confirm against the function signature.
        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
        ):
            pytest.skip("No FusedAttn backend found")

        def target_func(qkv, bias, mask):
            """Scalar loss: mean of the fused attention output."""
            return jnp.mean(
                fused_attn_qkvpacked(
                    qkv,
                    bias,
                    mask,
                    None,
                    attn_bias_type=attn_bias_type,
                    attn_mask_type=attn_mask_type,
                    scaling_factor=scaling_factor,
                    dropout_probability=dropout_prob,
                    is_training=is_training,
                )
            )

        def ref_func(qkv, bias, mask):
            """Scalar loss: mean of Flax ``dot_product_attention`` on unpacked QKV."""
            query, key, value = jnp.split(qkv, [1, 2], axis=-3)
            # Drop only the split axis so any other size-1 dimension survives.
            query = jnp.squeeze(query, axis=-3)
            key = jnp.squeeze(key, axis=-3)
            value = jnp.squeeze(value, axis=-3)

            output = dot_product_attention(
                query,
                key,
                value,
                bias=bias,
                mask=mask,
                # Flax disables dropout when deterministic is True, so training
                # maps to deterministic=False. dropout_prob is 0.0 here, so the
                # result is the same either way.
                deterministic=not is_training,
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )

            return jnp.mean(output).astype(dtype)

        with_bias = attn_bias_type != AttnBiasType.NO_BIAS
        (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, with_bias, attn_mask_type, dtype
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
            bias_ = (
                jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
            )
            mask_ = (
                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
            )

            # Differentiate w.r.t. qkv, and w.r.t. bias too when one is present.
            grad_args = (0, 1) if with_bias else (0,)
            out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)

            compare_ops(
                target_func,
                ref_func,
                [qkv_, bias_, mask_],
                collective_count_ref,
                grad_args=grad_args,
                metric_fwd_dtype=dtype,
                metric_bwd_dtype=dtype,
                in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
                out_shardings=(None, out_grad_shardings),
            )
188
189
190
191
192
193


class TestDistributedCrossAttn:
    """Compare sharded ``fused_attn_kvpacked`` against a Flax reference attention."""

    def generate_collectives_count_ref(self):
        """Build the expected collective byte counts handed to ``compare_ops``.

        Returns:
            The result of ``generate_collectives_count`` with a 4-byte
            all-reduce and no other collectives.
        """
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
        """Create random query / packed-KV inputs and their partition specs.

        Args:
            shape: Query shape, unpacked here as ``(batch, seqlen, heads, hidden)``.
            mesh_resource: Object exposing ``dp_resource`` and ``tp_resource``.
            attn_mask_type: ``AttnMaskType`` selecting which mask helper is used.
            dtype: Element dtype of the generated tensors.

        Returns:
            ``((q, kv, mask), (q_pspec, kv_pspec, mask_pspec))``; ``mask`` and
            ``mask_pspec`` may be ``None``.
        """
        batch, seqlen, heads, hidden = shape

        q = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        # Packed key/value tensor: same layout as q with an extra axis of size 2.
        kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype)

        # NOTE(review): PADDING_MASK is paired with make_causal_mask and
        # CAUSAL_MASK with make_self_mask. Judging by the helper names this
        # looks swapped, but the helpers are not visible here -- confirm
        # before changing.
        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            mask = make_causal_mask(batch, seqlen)
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        # Batch axis follows the data-parallel resource, head axis the
        # tensor-parallel resource.
        q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)

        kv_pspec = PartitionSpec(
            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
        )
        mask_pspec = (
            PartitionSpec(mesh_resource.dp_resource, None, None, None)
            if attn_mask_type != AttnMaskType.NO_MASK
            else None
        )

        return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        """Check fused packed-KV cross-attention against the Flax reference.

        The test is skipped when ``is_fused_attn_kernel_available`` reports no
        kernel for the requested configuration.
        """
        attn_bias_type = AttnBiasType.NO_BIAS
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, num_head, hidden = data_shape

        # dtype, num_head and seqlen are each passed twice; presumably the
        # q-side and kv-side values -- confirm against the function signature.
        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
        ):
            pytest.skip("No FusedAttn backend found")

        def target_func(q, kv, mask):
            """Scalar loss: mean of the fused attention output."""
            return jnp.mean(
                fused_attn_kvpacked(
                    q,
                    kv,
                    None,
                    mask,
                    None,
                    attn_bias_type=attn_bias_type,
                    attn_mask_type=attn_mask_type,
                    scaling_factor=scaling_factor,
                    dropout_probability=dropout_prob,
                    is_training=is_training,
                )
            )

        def ref_func(query, kv, mask):
            """Scalar loss: mean of Flax ``dot_product_attention`` on unpacked KV."""
            key, value = jnp.split(kv, [1], axis=-3)
            # Drop only the split axis so any other size-1 dimension survives.
            # query has no split axis and is used as-is.
            key = jnp.squeeze(key, axis=-3)
            value = jnp.squeeze(value, axis=-3)

            output = dot_product_attention(
                query,
                key,
                value,
                bias=None,
                mask=mask,
                # Flax disables dropout when deterministic is True, so training
                # maps to deterministic=False. dropout_prob is 0.0 here, so the
                # result is the same either way.
                deterministic=not is_training,
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )

            return jnp.mean(output).astype(dtype)

        (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, attn_mask_type, dtype
        )
        collective_count_ref = self.generate_collectives_count_ref()
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
            kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
            mask_ = (
                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
            )

            # Gradients are taken w.r.t. both q and kv.
            compare_ops(
                target_func,
                ref_func,
                [q_, kv_, mask_],
                collective_count_ref,
                grad_args=(0, 1),
                metric_fwd_dtype=dtype,
                metric_bwd_dtype=dtype,
                in_shardings=(q_pspec, kv_pspec, mask_pspec),
                out_shardings=(None, (q_pspec, kv_pspec)),
            )