# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from distributed_test_base import (
    generate_configs,
    generate_context_parallel_configs,
    generate_collectives_count,
)
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
    is_fused_attn_kernel_available,
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
    CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource
import pytest

from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat

DTYPES = [jnp.bfloat16]


class TestDistributedSelfAttn:

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None
        tp_size = 1
        if mesh_resource.tp_resource is not None:
            idx = mesh_axes.index(mesh_resource.tp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        # all-reduce covers the scalar loss and, when DP is enabled, the dbias gradient
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
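        # Worked example (hypothetical numbers): heads=12, tp_size=2, seqlen=512, bf16
        # (itemsize=2) gives bias_bytes = 6 * 512 * 512 * 2 = 3,145,728 bytes of dbias
        # all-reduced when DP is enabled, on top of the 4-byte loss all-reduce.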
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "data_shape",
        [
            pytest.param((32, 512, 12, 64), id="32-512-12-64"),
            pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_mask_type",
        [
            pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        ],
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        bias_shape,
        attn_mask_type,
        dtype,
    ):
        dropout_prob = 0.0
        is_training = True

        batch, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            pytest.skip(f"No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref(
            mesh_shape,
            mesh_axes,
            mesh_resource,
            attn_bias_type != AttnBiasType.NO_BIAS,
            data_shape,
            dtype,
        )
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BS3HD,
            bias_shape,
            None,  # no window
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()


class TestDistributedCrossAttn:

    def generate_collectives_count_ref(self):
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True

        batch, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
            None,  # no window
        ):
            pytest.skip(f"No FusedAttn backend found")

        col_ref = self.generate_collectives_count_ref()
        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_head,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            QKVLayout.BSHD_BS2HD,
            bias_shape,
            None,  # no window
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            coll_count_ref=col_ref,
        )
        runner.test_backward()


@pytest.mark.parametrize(
    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
    "data_shape",
    [
        # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
        pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
        pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
    ],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
    ],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "load_balanced",
    [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
class TestDistributedContextParallelSelfAttn:

    def impl_test_context_parallel_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
        cp_strategy,
    ):
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
        is_training = True
        dp_size, cp_size, tp_size = mesh_shape
        qkv_format = qkv_layout.get_qkv_format()

        batch, seqlen, num_head, hidden = data_shape

        # Scale the sequence length by 2*CP so it's never too small as the test scales up.
        # 2*CP is used since load balancing splits the sequence into two chunks per CP rank.
        seqlen = seqlen * cp_size * 2
        data_shape = batch, seqlen, num_head, hidden
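        # Illustrative example: with cp_size=4 and a base seqlen of 128, the test runs
        # with seqlen = 128 * 4 * 2 = 1024, i.e. eight 128-token chunks, two per CP rank.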

        num_kv_heads = num_head // kv_groups

        runner = FusedAttnRunner(
            batch,
            seqlen,
            seqlen,
            num_head,
            num_kv_heads,
            hidden,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            None,  # no window
            SeqDescFormat.Seqlens,
            number_of_devices=device_count,
            mesh_shape=mesh_shape,
            mesh_axes=mesh_axes,
            mesh_resource=mesh_resource,
            cp_strategy=cp_strategy,
            cp_load_balanced=load_balanced,
        )

        def check_has_backend_for_mask(mask_type):
            return is_fused_attn_kernel_available(
                dtype,
                dtype,
                qkv_layout,
                attn_bias_type,
                mask_type,
                dropout_prob,
                num_head,
                num_kv_heads,
                seqlen,
                seqlen,
                hidden,
                None,  # no SWA for CP
            )

        # For causal masking we depend on having bottom-right support as well.
        # The API does not check this; instead we rely on lower-level checks to raise
        # an exception if the backend is not supported. This was a deliberate API
        # decision to keep the CP size or flag out of the function.
        has_backend = check_has_backend_for_mask(attn_mask_type)
        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

        if not has_backend:
            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

        if dp_size > 1 and batch % dp_size != 0:
            pytest.skip(f"Skipping: {batch=} is not a multiple of {dp_size=}")

        # make sure the heads divide evenly across kv_groups and the TP axis
        if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
            pytest.skip(
                f"Skipping: {num_head=} is not divisible by {kv_groups=}, or the resulting"
                f" KV heads are not divisible by {tp_size=}"
            )

        runner.test_backward()

    def test_context_parallel_allgather_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        return self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
        )

    def test_context_parallel_ring_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    ):
        return self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            kv_groups,
            attn_mask_type,
            dtype,
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
        )

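# Note: reorder_causal_load_balancing splits the sequence axis into 2*cp_size chunks and
# pairs chunk i with chunk (2*cp_size - 1 - i) on each CP rank. Illustrative example with
# cp_size=2: the chunk order [0, 1, 2, 3] becomes [0, 3, 1, 2]; the inverse restores it.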

class TestReorderCausalLoadBalancing:
    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest.mark.parametrize(
        "shape",
        [
            pytest.param([1, 16, 1, 1], id="1-16-1-1"),
            pytest.param([4, 32, 12, 32], id="4-32-12-32"),
            pytest.param([3, 32, 8, 64], id="3-32-8-64"),
        ],
    )
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    def test(self, cp_size, shape, qkv_format):
        tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        if qkv_format == QKVFormat.SBHD:
            tensor = tensor.swapaxes(0, 1)

        ref = tensor.copy()

        reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
        inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])

        reordered = reorder(tensor, cp_size, qkv_format)
        inversed = inverse(reordered, cp_size, qkv_format)

        assert jnp.array_equal(inversed, ref)