test_distributed_fused_attn.py 12.4 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

import pytest
6
from functools import partial
7
8
9
10
11
12

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
13
from jax.sharding import Mesh, NamedSharding, PartitionSpec
14
15
16
17
18
19
from distributed_test_base import (
    generate_configs,
    generate_context_parallel_configs,
    generate_collectives_count,
    compare_ops,
)
20
21
22
23
24
25
from utils import (
    make_causal_mask,
    make_self_mask,
    assert_allclose,
    print_debug_tensor_stats,
)
26
from transformer_engine.jax import fp8_autocast
27
28
from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
29
    fused_attn,
30
31
    AttnBiasType,
    AttnMaskType,
32
    QKVLayout,
33
34
35
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
36
    CPStrategy,
37
)
38
from transformer_engine.jax.sharding import MeshResource
39

40
from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask
41
42
43
44
45
46

DTYPES = [jnp.float16, jnp.bfloat16]


class TestDistributedSelfAttn:
    """Distributed self-attention tests: run a fused-attention backward pass on a
    device mesh and validate gradients plus collective-communication byte counts."""

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        """Build the expected collective byte counts for one self-attention case.

        Args:
            mesh_shape: Sizes of the device-mesh dimensions.
            mesh_axes: Mesh axis names, aligned with ``mesh_shape``.
            mesh_resource: Mapping of logical parallelism resources to mesh axes.
            with_bias: True when an attention bias (and thus a dbias grad) is used.
            shape: Input shape ``(batch, seqlen, heads, hidden)``.
            dtype: Tensor dtype; determines the per-element byte size.

        Returns:
            Reference counts object from ``generate_collectives_count``.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None
        tp_size = 1
        if mesh_resource.tp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        # dbias is sharded over heads by TP; with DP enabled its gradient is all-reduced.
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        # for loss and dbias
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "data_shape",
        [
            pytest.param((32, 512, 12, 64), id="32-512-12-64"),
            pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_mask_type",
        [
            pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        ],
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
    ):
        """Run a distributed self-attention backward pass and compare against the
        single-device reference, including expected collective byte counts."""
        dropout_prob = 0.0
        is_training = True

        batch, seqlen, num_head, hidden = data_shape

        # Skip configurations for which no fused-attention kernel exists.
        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            # Fix: removed extraneous f-string prefix (no placeholders).
            pytest.skip("No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref(
            mesh_shape,
            mesh_axes,
            mesh_resource,
            attn_bias_type != AttnBiasType.NO_BIAS,
            data_shape,
            dtype,
        )
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BS3HD,
            bias_shape,
            None,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()


class TestDistributedCrossAttn:
    """Distributed cross-attention tests: separate Q vs. combined KV layout,
    validated against the single-device reference with collective counts."""

    def generate_collectives_count_ref(self):
        """Expected collectives: only the FP32 loss all-reduce (no bias here)."""
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        """Run a distributed cross-attention backward pass (BSHD_BS2HD layout)."""
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True

        batch, seqlen, num_head, hidden = data_shape

        # Skip configurations for which no fused-attention kernel exists.
        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            # Fix: removed extraneous f-string prefix (no placeholders).
            pytest.skip("No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref()
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BSHD_BS2HD,
            bias_shape,
            None,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()


217
218
219
220
221
222
@pytest.mark.parametrize(
    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
    "data_shape",
    [
        # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
        pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
        pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
    ],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
    ],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "load_balanced",
    [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
class TestDistributedContextParallelSelfAttn:
    """Context-parallel self-attention tests covering the all-gather and ring
    CP strategies, GQA head groupings, and load-balanced sequence reordering."""

    def impl_test_context_parallel_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        cp_strategy,
    ):
        """Shared driver: skip unsupported configurations, then run the
        backward-pass comparison for the given context-parallel strategy."""
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True
        dp_size, cp_size, tp_size = mesh_shape

        batch, seqlen, num_head, hidden = data_shape

        # Scale the sequence length by 2*CP so its never too small as we scale up test.
        # 2*CP is used since we split into two CP groups for load balancing.
        seqlen = seqlen * cp_size * 2
        data_shape = batch, seqlen, num_head, hidden

        num_kv_heads = num_head // kv_groups

        def check_has_backend_for_mask(mask_type):
            # Bug fix: this closure previously ignored its `mask_type` argument and
            # always re-checked the enclosing `attn_mask_type`, so the
            # CAUSAL_BOTTOM_RIGHT_MASK probe below never actually tested that mask.
            return is_fused_attn_kernel_available(
                dtype,
                dtype,
                qkv_layout,
                attn_bias_type,
                mask_type,
                dropout_prob,
                num_head,
                num_kv_heads,
                seqlen,
                seqlen,
                hidden,
                None,
            )  # no SWA for CP

        # For causal masking we depend on having bottom right support also.
        # The API does not check this and instead we rely on lower level checks to raise
        # and exception if the step backend is not supported. This was a deliberate API
        # decision to keep the CP size or flag out of the function.
        has_backend = check_has_backend_for_mask(attn_mask_type)
        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

        if not has_backend:
            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

        if dp_size > 1 and batch % dp_size != 0:
            pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

        # make sure the heads evenly divide the kv_groups and the tp axis
        if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
            pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

        # All skip guards passed; only now construct the runner (hoisted the skips
        # above construction so skipped cases don't pay for runner setup).
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_kv_heads,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            None,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            cp_strategy=cp_strategy,
            cp_load_balanced=load_balanced,
        )
        runner.test_backward()

    def test_context_parallel_allgather_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        """All-gather context-parallel strategy variant."""
        return self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
        )

    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        """Ring context-parallel strategy variant."""
        return self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
        )

394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419

class TestReorderCausalLoadBalancing:
    """Round-trip tests for the causal load-balancing reorder ops: applying the
    inverse reorder after the reorder must reproduce the original tensor."""

    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest.mark.parametrize(
        "shape",
        [
            pytest.param([1, 16, 1, 1], id="1-16-1-1"),
            pytest.param([4, 32, 12, 32], id="4-32-12-32"),
            pytest.param([3, 32, 8, 64], id="3-32-8-64"),
        ],
    )
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    def test(self, cp_size, shape, qkv_format):
        data = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        # SBHD expects sequence-major layout, so flip the leading two axes.
        if qkv_format == QKVFormat.SBHD:
            data = data.swapaxes(0, 1)

        expected = data.copy()

        # cp_size and qkv_format are compile-time constants for the jitted ops.
        reorder_fn = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
        inverse_fn = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])

        round_tripped = inverse_fn(reorder_fn(data, cp_size, qkv_format), cp_size, qkv_format)

        assert jnp.array_equal(round_tripped, expected)