# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
from jax.ad_checkpoint import checkpoint_name

import jax
import jax.numpy as jnp

from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout

from . import cpp_extensions as tex


class AttnBiasType(Enum):
    """
    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """
    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of paddings at the end of each sequence.
    CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    """
37
38
39
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
40
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
41
42


class QKVLayout(Enum):
    """QKV layout"""
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert string attn_mask_type to AttnMaskType
    TE-JAX currently fall back to the padding version kernels for the libraries integration.
    The overhead between padding and non-padding version should be small.
    However, we will lease this limitation in the near feature.
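
    For example, 'causal' maps to AttnMaskType.CAUSAL_MASK, while both 'padding_causal'
    and 'causal_padding' map to AttnMaskType.PADDING_CAUSAL_MASK.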
    """
    match attn_mask_type:
        case 'no_mask':
            return AttnMaskType.NO_MASK
        case 'padding':
            return AttnMaskType.PADDING_MASK
        case 'causal':
            return AttnMaskType.CAUSAL_MASK
        case 'padding_causal' | 'causal_padding':
            return AttnMaskType.PADDING_CAUSAL_MASK
    raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                     "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")


def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type,
                                   dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen,
                                   kv_max_seqlen, head_dim):
    """
    Check whether the fused attention kernel is available for the given configuration.
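
    For instance (illustrative values only), a check for BF16 self-attention with packed QKV,
    16 heads, sequence length 512 and head dimension 64 could look like:

        is_fused_attn_kernel_available(jnp.bfloat16, jnp.bfloat16, QKVLayout.BS3HD,
                                       AttnBiasType.NO_BIAS, AttnMaskType.PADDING_MASK,
                                       0.0, 16, 16, 512, 512, 64)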
    """
    return tex.FusedAttnHelper(
        q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value, attn_mask_type.value,
        dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen,
        head_dim).is_fused_attn_kernel_available()


def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                         seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                         attn_mask_type: AttnMaskType, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Fused attention with the qkvpacked inputs
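
    Illustrative usage (a sketch, not executed here): `qkv` is assumed to be packed as
    (batch, seqlen, 3, num_heads, head_dim), with `padding_mask` and `rng_seed` supplied
    by the caller:

        out = fused_attn_qkvpacked(qkv, None, padding_mask, rng_seed,
                                   attn_bias_type=AttnBiasType.NO_BIAS,
                                   attn_mask_type=AttnMaskType.PADDING_MASK,
                                   scaling_factor=head_dim ** -0.5,
                                   dropout_probability=0.0, is_training=True)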
    """
    output = _fused_attn_qkvpacked(qkv,
                                   bias,
                                   mask,
                                   seed,
                                   attn_bias_type=attn_bias_type,
                                   attn_mask_type=attn_mask_type,
                                   scaling_factor=scaling_factor,
                                   dropout_probability=dropout_probability,
                                   is_training=is_training)

    return output


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                          seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                          attn_mask_type: AttnMaskType, scaling_factor: float,
                          dropout_probability: float, is_training: bool):
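    # attn_bias_type, attn_mask_type, scaling_factor, dropout_probability and is_training are
    # non-differentiable configuration arguments; the VJP below only produces gradients for
    # qkv and bias, returning None for the mask and seed inputs.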

    output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type,
                                               attn_mask_type, scaling_factor, dropout_probability,
                                               is_training)
    return output


def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                                   seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                                   attn_mask_type: AttnMaskType, scaling_factor: float,
                                   dropout_probability: float, is_training: bool):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, seqlen, *_ = qkv.shape
        actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
    else:
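        # `mask` is True (1) at positions to be masked out, so invert it and count the
        # remaining valid positions along the query axis; the [..., 0, 0] entry of that
        # sum recovers the per-batch sequence length.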
        assert mask is not None
        mask = jnp.logical_not(mask)
        actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
    output, softmax_aux, rng_state = tex.fused_attn_fwd_qkvpacked(
        qkv,
        bias,
        actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
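    # Tag the saved activations so that name-based rematerialization policies
    # (e.g. jax.checkpoint_policies.save_only_these_names) can elect to keep them.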
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)


def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
                                   dropout_probability, is_training, ctx, dz):
    qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx

    grad_qkv, grad_bias = tex.fused_attn_bwd_qkvpacked(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        actual_seqlen,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_qkv, grad_bias, None, None


_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)


def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                        seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                        attn_mask_type: AttnMaskType, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Fused attention with the kvpacked inputs
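
    Illustrative usage (a sketch, not executed here): `q` is assumed to have shape
    (batch, q_seqlen, num_heads, head_dim) and `kv` the packed shape
    (batch, kv_seqlen, 2, num_heads, head_dim), with `padding_mask` and `rng_seed`
    supplied by the caller:

        out = fused_attn_kvpacked(q, kv, None, padding_mask, rng_seed,
                                  attn_bias_type=AttnBiasType.NO_BIAS,
                                  attn_mask_type=AttnMaskType.PADDING_MASK,
                                  scaling_factor=head_dim ** -0.5,
                                  dropout_probability=0.0, is_training=True)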
    """

    output = _fused_attn_kvpacked(q,
                                  kv,
                                  bias,
                                  mask,
                                  seed,
                                  attn_bias_type=attn_bias_type,
                                  attn_mask_type=attn_mask_type,
                                  scaling_factor=scaling_factor,
                                  dropout_probability=dropout_probability,
                                  is_training=is_training)

    return output


@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                         seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                         attn_mask_type: AttnMaskType, scaling_factor: float,
                         dropout_probability: float, is_training: bool):

    output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type,
                                              attn_mask_type, scaling_factor, dropout_probability,
                                              is_training)
    return output


def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                  scaling_factor, dropout_probability, is_training):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, s_q, *_ = q.shape
        s_kv = kv.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        else:
            # When mask is causal, the actual seqlen is not the last row, use max to find it
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output, softmax_aux, rng_state = tex.fused_attn_fwd_kvpacked(
        q,
        kv,
        bias,
        q_actual_seqlen,
        kv_actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)


def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor,
                                  dropout_probability, is_training, ctx, dz):
    q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

    grad_q, grad_kv, grad_bias = tex.fused_attn_bwd_kvpacked(
        q,
        kv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        q_actual_seqlen,
        kv_actual_seqlen,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_q, grad_kv, grad_bias, None, None


_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)


def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
               seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
               scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Dot product attention with separated query, key and value tensors
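
    Illustrative usage (a sketch, not executed here): `q`, `k` and `v` are assumed to have
    shape (batch, seqlen, num_heads, head_dim), with `padding_mask` and `rng_seed` supplied
    by the caller:

        out = fused_attn(q, k, v, None, padding_mask, rng_seed,
                         attn_bias_type=AttnBiasType.NO_BIAS,
                         attn_mask_type=AttnMaskType.PADDING_MASK,
                         scaling_factor=head_dim ** -0.5,
                         dropout_probability=0.0, is_training=True)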
    """

    output = _fused_attn(q,
                         k,
                         v,
                         bias,
                         mask,
                         seed,
                         attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         scaling_factor=scaling_factor,
                         dropout_probability=dropout_probability,
                         is_training=is_training)

    return output


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
                mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float,
                is_training: bool):

    output, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type,
                                     scaling_factor, dropout_probability, is_training)
    return output


def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                         dropout_probability, is_training):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, s_q, *_ = q.shape
        s_kv = k.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        else:
            # When mask is causal, the actual seqlen is not the last row, use max to find it
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        q,
        k,
        v,
        bias,
        q_actual_seqlen,
        kv_actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen,
                    kv_actual_seqlen)


def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                         is_training, ctx, dz):
    q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

    grad_q, grad_k, grad_v, grad_bias = tex.fused_attn_bwd(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        q_actual_seqlen,
        kv_actual_seqlen,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_q, grad_k, grad_v, grad_bias, None, None


_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)