# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout

from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked
from .cpp_extensions import fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd


class AttnBiasType(Enum):
    """
    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS
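

# A minimal unfused reference for the bias modes above (an illustrative
# sketch only, not used by the fused kernels; `logits` is assumed to be the
# raw q @ k^T attention scores):
def _reference_bias_softmax(logits, bias, scale, bias_type):
    if bias_type == AttnBiasType.NO_BIAS:
        return jax.nn.softmax(scale * logits)
    if bias_type == AttnBiasType.PRE_SCALE_BIAS:
        return jax.nn.softmax(scale * (logits + bias))
    assert bias_type == AttnBiasType.POST_SCALE_BIAS
    return jax.nn.softmax(scale * logits + bias)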


class AttnMaskType(Enum):
    """
    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    """
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
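
# Note: masks passed to the functions below are assumed to follow the TE-JAX
# convention that nonzero entries mark positions to be masked OUT; the forward
# rules invert the mask before deriving per-batch sequence lengths from it.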


class QKVLayout(Enum):
    """
    BS3HD: Q, K and V packed into a single tensor of shape (b, s, 3, h, d).
    BSHD_BS2HD: Separate Q, with K and V packed into a (b, s, 2, h, d) tensor.
    BSHD_BSHD_BSHD: Separate Q, K and V tensors, each of shape (b, s, h, d).
    """
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
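
# A minimal sketch of building each layout from separate q, k and v tensors of
# shape (b, s, h, d) (self-attention, so all sequence lengths match; names are
# illustrative):
#
#     qkv = jnp.stack([q, k, v], axis=2)    # BS3HD:      (b, s, 3, h, d)
#     kv = jnp.stack([k, v], axis=2)        # BSHD_BS2HD: (b, s, 2, h, d)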


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert a string attn_mask_type to an AttnMaskType.
    TE-JAX currently falls back to the padding version of the kernels for library integration.
    The overhead between the padding and non-padding versions should be small.
    However, we will lift this limitation in the near future.
    """
    match attn_mask_type:
        case 'no_mask':
            return AttnMaskType.NO_MASK
        case 'padding':
            return AttnMaskType.PADDING_MASK
        case 'causal':
            return AttnMaskType.CAUSAL_MASK
        case 'padding_causal' | 'causal_padding':
            return AttnMaskType.PADDING_CAUSAL_MASK
    raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                     "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")


def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type,
                                   dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen,
                                   kv_max_seqlen, head_dim):
    """
    Check whether a fused attention kernel is available for the given configuration
    """
    return FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value,
                           attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads,
                           q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available()
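
# A hedged usage sketch: probe kernel availability before choosing a code
# path (`unfused_attn` is a hypothetical fallback, not part of this module):
#
#     if is_fused_attn_kernel_available(jnp.float16, jnp.float16,
#                                       QKVLayout.BS3HD, AttnBiasType.NO_BIAS,
#                                       AttnMaskType.CAUSAL_MASK, 0.0,
#                                       num_heads, num_heads, seqlen, seqlen,
#                                       head_dim):
#         out = fused_attn_qkvpacked(...)
#     else:
#         out = unfused_attn(...)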


def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                         seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                         attn_mask_type: AttnMaskType, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Fused attention with the qkv-packed input (Q, K and V packed in one tensor)
    """
    output = _fused_attn_qkvpacked(qkv,
                                   bias,
                                   mask,
                                   seed,
                                   attn_bias_type=attn_bias_type,
                                   attn_mask_type=attn_mask_type,
                                   scaling_factor=scaling_factor,
                                   dropout_probability=dropout_probability,
                                   is_training=is_training)

    return output
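
# A minimal usage sketch (shapes follow the BS3HD layout; sizes and `key` are
# illustrative). With a causal mask the forward rule derives the sequence
# lengths from the shape, so no mask tensor is needed:
#
#     qkv = jnp.zeros((batch, seqlen, 3, num_heads, head_dim), jnp.float16)
#     out = fused_attn_qkvpacked(qkv, None, None, key,
#                                attn_bias_type=AttnBiasType.NO_BIAS,
#                                attn_mask_type=AttnMaskType.CAUSAL_MASK,
#                                scaling_factor=head_dim ** -0.5,
#                                dropout_probability=0.0, is_training=True)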


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                          seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                          attn_mask_type: AttnMaskType, scaling_factor: float,
                          dropout_probability: float, is_training: bool):

    output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type,
                                               attn_mask_type, scaling_factor, dropout_probability,
                                               is_training)
    return output


def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                                   seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                                   attn_mask_type: AttnMaskType, scaling_factor: float,
                                   dropout_probability: float, is_training: bool):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, seqlen, *_ = qkv.shape
        actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
    else:
        assert mask is not None
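        # Invert the mask (nonzero originally means "masked out", so after the
        # inversion nonzero means "attendable"), then count the attendable
        # positions to recover each batch's actual sequence length.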
        mask = jnp.logical_not(mask)
        actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
    output, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
        qkv,
        bias,
        actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)


def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
                                   dropout_probability, is_training, ctx, dz):
    qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx

    grad_qkv, grad_bias = fused_attn_bwd_qkvpacked(qkv,
                                                   bias,
                                                   softmax_aux,
                                                   rng_state,
                                                   output,
                                                   dz,
                                                   actual_seqlen,
                                                   attn_bias_type=attn_bias_type.value,
                                                   attn_mask_type=attn_mask_type.value,
                                                   scaling_factor=scaling_factor,
                                                   dropout_probability=dropout_probability,
                                                   is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_qkv, grad_bias, None, None


_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)
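
# The custom-VJP wiring follows the standard JAX pattern: the forward rule
# returns (output, residuals); the backward rule receives the
# non-differentiable configuration arguments (declared via nondiff_argnums),
# the residuals, and the output cotangent dz, and returns one gradient per
# differentiable input (None for mask and seed, which carry no gradient).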


def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                        seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                        attn_mask_type: AttnMaskType, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Fused attention with the kv-packed inputs (K and V packed in one tensor)
    """

    output = _fused_attn_kvpacked(q,
                                  kv,
                                  bias,
                                  mask,
                                  seed,
                                  attn_bias_type=attn_bias_type,
                                  attn_mask_type=attn_mask_type,
                                  scaling_factor=scaling_factor,
                                  dropout_probability=dropout_probability,
                                  is_training=is_training)

    return output
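
# A minimal usage sketch for cross attention with different q/kv lengths
# (shapes follow the BSHD_BS2HD layout; the mask shape (b, 1, s_q, s_kv) and
# all names are illustrative assumptions):
#
#     q = jnp.zeros((batch, s_q, num_heads, head_dim), jnp.float16)
#     kv = jnp.zeros((batch, s_kv, 2, num_heads, head_dim), jnp.float16)
#     out = fused_attn_kvpacked(q, kv, None, mask, key,
#                               attn_bias_type=AttnBiasType.NO_BIAS,
#                               attn_mask_type=AttnMaskType.PADDING_MASK,
#                               scaling_factor=head_dim ** -0.5,
#                               dropout_probability=0.0, is_training=False)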


@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                         seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                         attn_mask_type: AttnMaskType, scaling_factor: float,
                         dropout_probability: float, is_training: bool):

    output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type,
                                              attn_mask_type, scaling_factor, dropout_probability,
                                              is_training)
    return output


def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                  scaling_factor, dropout_probability, is_training):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, s_q, *_ = q.shape
        s_kv = kv.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        else:
            # With a causal mask, no single row is guaranteed to see the full
            # sequence, so take the max of the per-row valid counts to recover
            # the actual kv seqlen
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
        q,
        kv,
        bias,
        q_actual_seqlen,
        kv_actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)


def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
                                  dropout_probability, is_training, ctx, dz):
    q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

    grad_q, grad_kv, grad_bias = fused_attn_bwd_kvpacked(q,
                                                         kv,
                                                         bias,
                                                         softmax_aux,
                                                         rng_state,
                                                         output,
                                                         dz,
                                                         q_actual_seqlen,
                                                         kv_actual_seqlen,
                                                         attn_bias_type=attn_bias_type.value,
                                                         attn_mask_type=attn_mask_type.value,
                                                         scaling_factor=scaling_factor,
                                                         dropout_probability=dropout_probability,
                                                         is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_q, grad_kv, grad_bias, None, None


_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)


def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
               seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
               scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Dot product attention with separate query, key and value tensors
    """

    output = _fused_attn(q,
                         k,
                         v,
                         bias,
                         mask,
                         seed,
                         attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         scaling_factor=scaling_factor,
                         dropout_probability=dropout_probability,
                         is_training=is_training)

    return output
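
# Usage mirrors the packed variants above, but with separate q, k and v
# tensors, each of assumed shape (b, s, h, d) per the BSHD_BSHD_BSHD layout.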


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
                mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float,
                is_training: bool):

    output, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type,
                                     scaling_factor, dropout_probability, is_training)
    return output


def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                         dropout_probability, is_training):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, s_q, *_ = q.shape
        s_kv = k.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        else:
            # With a causal mask, no single row is guaranteed to see the full
            # sequence, so take the max of the per-row valid counts to recover
            # the actual kv seqlen
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output, softmax_aux, rng_state = fused_attn_fwd(q,
                                                    k,
                                                    v,
                                                    bias,
                                                    q_actual_seqlen,
                                                    kv_actual_seqlen,
                                                    seed,
                                                    attn_bias_type=attn_bias_type.value,
                                                    attn_mask_type=attn_mask_type.value,
                                                    scaling_factor=scaling_factor,
                                                    dropout_probability=dropout_probability,
                                                    is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen,
                    kv_actual_seqlen)


def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                         is_training, ctx, dz):
    q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

    grad_q, grad_k, grad_v, grad_bias = fused_attn_bwd(q,
                                                       k,
                                                       v,
                                                       bias,
                                                       softmax_aux,
                                                       rng_state,
                                                       output,
                                                       dz,
                                                       q_actual_seqlen,
                                                       kv_actual_seqlen,
                                                       attn_bias_type=attn_bias_type.value,
                                                       attn_mask_type=attn_mask_type.value,
                                                       scaling_factor=scaling_factor,
                                                       dropout_probability=dropout_probability,
                                                       is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_q, grad_k, grad_v, grad_bias, None, None


_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
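
# A hedged sketch of differentiating through the fused op (names and shapes
# are illustrative; dropout is disabled, so no PRNG seed is consumed):
#
#     def loss_fn(q, k, v):
#         out = fused_attn(q, k, v, None, None, None,
#                          AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
#                          head_dim ** -0.5, 0.0, True)
#         return jnp.sum(out)
#
#     dq, dk, dv = jax.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)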