# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout

DTYPES = [jnp.float16, jnp.bfloat16]


class TestDistributedSelfAttn:
    """Distributed self-attention tests for ``fused_attn_qkvpacked``.

    The fused kernel is compared against Flax's ``dot_product_attention`` on a
    device mesh, with the comparison itself delegated to ``compare_ops``.
    """

    def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape,
                                       dtype):
        """Return the expected collective counts for one self-attention run.

        The all-reduce budget is 4 bytes for the scalar FP32 loss, plus the size
        of the per-device bias gradient when a bias is used and data parallelism
        is enabled. All-gather and other collectives are expected to be zero.

        Args:
            mesh_shape: Size of each mesh axis, indexed like ``mesh_axes``.
            mesh_axes: Mesh axis names.
            mesh_resource: Object exposing ``dp_resource`` and ``tp_resource``.
            with_bias: Whether an attention bias is part of the test.
            shape: Packed QKV shape; positions 1 and 3 are read as seqlen and heads.
            dtype: Input dtype, used for the bias element size.

        Returns:
            Whatever ``generate_collectives_count`` builds from those byte counts.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, _, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None

        # Heads are split across the tensor-parallel axis, if one is configured.
        tp_size = 1
        if mesh_resource.tp_resource is not None:
            tp_size = mesh_shape[mesh_axes.index(mesh_resource.tp_resource)]

        all_reduce_loss_bytes = 4    # 1 * FP32
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        # for loss and dbias
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype):
        """Create the packed QKV, optional bias and optional mask with their partition specs.

        Args:
            shape: Packed QKV shape; positions 0, 1 and 3 are read as batch, seqlen, heads.
            mesh_resource: Object exposing ``dp_resource`` and ``tp_resource``.
            with_bias: Whether to generate a ``(1, heads, seqlen, seqlen)`` bias.
            attn_mask_type: ``AttnMaskType`` selecting which mask helper is used.
            dtype: dtype of the random inputs.

        Returns:
            ``((qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec))``; ``bias`` and its
            spec are ``None`` without bias, and ``mask_pspec`` is ``None`` for ``NO_MASK``.
        """
        batch, seqlen, _, heads, _ = shape

        # Fixed PRNG keys keep the inputs reproducible across runs.
        qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)

        bias = None
        if with_bias:
            bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)

        # NOTE(review): PADDING_MASK uses make_causal_mask and CAUSAL_MASK uses make_self_mask,
        # which looks swapped going by the helper names -- confirm against utils. The target and
        # the reference receive the same mask, so the mapping is kept unchanged here.
        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            mask = make_causal_mask(batch, seqlen)
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
                                  None)
        bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \
                     if with_bias else None
        mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
                     if attn_mask_type != AttnMaskType.NO_MASK else None

        return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)

    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
    @pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
    @pytest.mark.parametrize(
        'attn_bias_type',
        [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS])
    @pytest.mark.parametrize('attn_mask_type',
                             [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
    @pytest.mark.parametrize('dtype', DTYPES)
    def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
                       attn_bias_type, attn_mask_type, dtype):
        """Compare ``fused_attn_qkvpacked`` with the Flax reference on a sharded mesh.

        The test is skipped when ``is_fused_attn_kernel_available`` reports no kernel for
        the parametrized configuration. Otherwise both functions reduce the attention
        output to a scalar mean and are passed to ``compare_ops`` together with the
        expected collective counts and the input/gradient shardings.
        """
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, _, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
                                              attn_mask_type, dropout_prob, num_head, num_head,
                                              seqlen, seqlen, hidden):
            pytest.skip("No FusedAttn backend found")

        def target_func(qkv, bias, mask):
            # The fourth positional argument is left as None -- presumably the dropout seed,
            # which is unused while dropout_prob is 0 (TODO confirm against fused_attn API).
            return jnp.mean(
                fused_attn_qkvpacked(qkv,
                                     bias,
                                     mask,
                                     None,
                                     attn_bias_type=attn_bias_type,
                                     attn_mask_type=attn_mask_type,
                                     scaling_factor=scaling_factor,
                                     dropout_probability=dropout_prob,
                                     is_training=is_training))

        def ref_func(qkv, bias, mask):
            # Unpack along the size-3 packing axis only; an axis-less squeeze would also
            # drop any other dimension that happens to have size 1.
            query, key, value = (jnp.squeeze(part, axis=-3)
                                 for part in jnp.split(qkv, [1, 2], axis=-3))

            output = dot_product_attention(query,
                                           key,
                                           value,
                                           bias=bias,
                                           mask=mask,
                                           # Training mode is the non-deterministic one in Flax.
                                           deterministic=not is_training,
                                           dropout_rate=dropout_prob,
                                           dropout_rng=None,
                                           dtype=jnp.float32)

            return jnp.mean(output).astype(dtype)

        with_bias = attn_bias_type != AttnBiasType.NO_BIAS
        (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \
                self.generate_inputs(data_shape, mesh_resource, with_bias,
                                     attn_mask_type, dtype)
        collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes,
                                                                   mesh_resource, with_bias,
                                                                   data_shape, dtype)
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
            bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \
                    if bias is not None else bias
            mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
                    if mask is not None else mask

            # The bias gradient is only requested when a bias is present.
            grad_args = (0, 1) if with_bias else (0,)
            out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)

            compare_ops(target_func,
                        ref_func, [qkv_, bias_, mask_],
                        collective_count_ref,
                        grad_args=grad_args,
                        metric_fwd_dtype=dtype,
                        metric_bwd_dtype=dtype,
                        in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
                        out_shardings=(None, out_grad_shardings))


class TestDistributedCrossAttn:
    """Distributed cross-attention tests for ``fused_attn_kvpacked``.

    The fused kernel is compared against Flax's ``dot_product_attention`` on a
    device mesh, with the comparison itself delegated to ``compare_ops``.
    """

    def generate_collectives_count_ref(self):
        """Return the expected collective counts: one 4-byte all-reduce for the loss."""
        # for loss
        all_reduce_loss_bytes = 4    # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
        """Create the query, packed KV and optional mask with their partition specs.

        Args:
            shape: Query shape, unpacked as ``(batch, seqlen, heads, hidden)``.
            mesh_resource: Object exposing ``dp_resource`` and ``tp_resource``.
            attn_mask_type: ``AttnMaskType`` selecting which mask helper is used.
            dtype: dtype of the random inputs.

        Returns:
            ``((q, kv, mask), (q_pspec, kv_pspec, mask_pspec))``; ``kv`` has shape
            ``(batch, seqlen, 2, heads, hidden)`` and ``mask_pspec`` is ``None`` for ``NO_MASK``.
        """
        batch, seqlen, heads, hidden = shape

        # Fixed PRNG keys keep the inputs reproducible across runs.
        q = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype)

        # NOTE(review): PADDING_MASK uses make_causal_mask and CAUSAL_MASK uses make_self_mask,
        # which looks swapped going by the helper names -- confirm against utils. The target and
        # the reference receive the same mask, so the mapping is kept unchanged here.
        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            mask = make_causal_mask(batch, seqlen)
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)

        kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
                                 None)
        mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
                     if attn_mask_type != AttnMaskType.NO_MASK else None

        return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)

    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
    @pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize('attn_mask_type',
                             [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
    @pytest.mark.parametrize('dtype', DTYPES)
    def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
                        attn_mask_type, dtype):
        """Compare ``fused_attn_kvpacked`` with the Flax reference on a sharded mesh.

        The test is skipped when ``is_fused_attn_kernel_available`` reports no kernel for
        the parametrized configuration. Otherwise both functions reduce the attention
        output to a scalar mean and are passed to ``compare_ops`` together with the
        expected collective counts and the input/gradient shardings. No bias is used.
        """
        attn_bias_type = AttnBiasType.NO_BIAS
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
                                              attn_mask_type, dropout_prob, num_head, num_head,
                                              seqlen, seqlen, hidden):
            pytest.skip("No FusedAttn backend found")

        def target_func(q, kv, mask):
            # The third and fifth positional arguments are left as None -- presumably the bias
            # (NO_BIAS here) and the dropout seed (dropout_prob is 0); TODO confirm against
            # the fused_attn API.
            return jnp.mean(
                fused_attn_kvpacked(q,
                                    kv,
                                    None,
                                    mask,
                                    None,
                                    attn_bias_type=attn_bias_type,
                                    attn_mask_type=attn_mask_type,
                                    scaling_factor=scaling_factor,
                                    dropout_probability=dropout_prob,
                                    is_training=is_training))

        def ref_func(query, kv, mask):
            # Unpack along the size-2 packing axis only; an axis-less squeeze would also
            # drop any other dimension that happens to have size 1. The query is already
            # unpacked, so it is passed through untouched.
            key, value = (jnp.squeeze(part, axis=-3) for part in jnp.split(kv, [1], axis=-3))

            output = dot_product_attention(query,
                                           key,
                                           value,
                                           bias=None,
                                           mask=mask,
                                           # Training mode is the non-deterministic one in Flax.
                                           deterministic=not is_training,
                                           dropout_rate=dropout_prob,
                                           dropout_rng=None,
                                           dtype=jnp.float32)

            return jnp.mean(output).astype(dtype)

        (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \
                self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype)
        collective_count_ref = self.generate_collectives_count_ref()
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
            kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
            mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
                    if mask is not None else mask

            compare_ops(target_func,
                        ref_func, [q_, kv_, mask_],
                        collective_count_ref,
                        grad_args=(0, 1),
                        metric_fwd_dtype=dtype,
                        metric_bwd_dtype=dtype,
                        in_shardings=(q_pspec, kv_pspec, mask_pspec),
                        out_shardings=(None, (q_pspec, kv_pspec)))