# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout

from . import cpp_extensions as tex


class AttnBiasType(Enum):
    """
    NO_BIAS: Softmax is performed as softmax(scale * qk)
    PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias))
    POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias)
    """

    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS
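
# The three bias modes place `bias` differently relative to the softmax scale
# (a minimal sketch; `scale`, `qk`, and `bias` are illustrative names, not part
# of this module):
#
#     probs = jax.nn.softmax(scale * qk)           # NO_BIAS
#     probs = jax.nn.softmax(scale * (qk + bias))  # PRE_SCALE_BIAS
#     probs = jax.nn.softmax(scale * qk + bias)    # POST_SCALE_BIAS
#
# PRE_SCALE_BIAS and POST_SCALE_BIAS coincide only when scale == 1.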


class AttnMaskType(Enum):
    """
    NO_MASK: No attention mask is applied.
    PADDING_MASK: Indicates the presence of padding at the end of each sequence.
    CAUSAL_MASK: An upper-triangular mask is applied to the softmax inputs.
    PADDING_CAUSAL_MASK: A combination of both causal and padding masks.
    """

    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
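
# Note: in this module, a mask entry of 1 marks a position to be masked *out*
# (not attended); the forward rules below invert the mask with jnp.logical_not
# before deriving per-sequence lengths. A minimal padding-mask sketch (shapes
# and names are illustrative):
#
#     pad = jnp.arange(max_seqlen) >= actual_len          # True on padded tail
#     mask = (pad[:, None] | pad[None, :])[None, None]    # (1, 1, s_q, s_kv)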


class QKVLayout(Enum):
    """
    BS3HD: q, k, and v are packed into one tensor of shape (b, s, 3, h, d).
    BSHD_BS2HD: q is a separate (b, s, h, d) tensor; k and v are packed into
        one (b, s, 2, h, d) tensor.
    BSHD_BSHD_BSHD: q, k, and v are separate (b, s, h, d) tensors.
    """

    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
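
# Packing separate q/k/v tensors (each (b, s, h, d)) into these layouts is a
# stack along a new axis (a minimal sketch; names are illustrative):
#
#     qkv = jnp.stack([q, k, v], axis=2)  # BS3HD: (b, s, 3, h, d)
#     kv = jnp.stack([k, v], axis=2)      # kv part of BSHD_BS2HD: (b, s, 2, h, d)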


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert string attn_mask_type to AttnMaskType
    TE-JAX currently fall back to the padding version kernels for the libraries integration.
    The overhead between padding and non-padding version should be small.
    However, we will lease this limitation in the near feature.
    """
59
    match attn_mask_type:
60
        case "no_mask":
61
            return AttnMaskType.NO_MASK
62
        case "padding":
63
            return AttnMaskType.PADDING_MASK
64
        case "causal":
65
            return AttnMaskType.CAUSAL_MASK
66
        case "padding_causal" | "causal_padding":
67
            return AttnMaskType.PADDING_CAUSAL_MASK
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
    raise ValueError(
        f"Unsupported {attn_mask_type=}, supported attn_mask_type="
        "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
    )
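

# Example (both spellings of the combined mask map to the same enum):
#
#     canonicalize_attn_mask_type("padding_causal")  # AttnMaskType.PADDING_CAUSAL_MASK
#     canonicalize_attn_mask_type("causal_padding")  # AttnMaskType.PADDING_CAUSAL_MASK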


def is_fused_attn_kernel_available(
    q_dtype,
    kv_dtype,
    qkv_layout,
    attn_bias_type,
    attn_mask_type,
    dropout_probability,
    q_num_heads,
    kv_num_heads,
    q_max_seqlen,
    kv_max_seqlen,
    head_dim,
):
    """
    Check whether a fused attention kernel is available for the given configuration.
    """
    return tex.FusedAttnHelper(
        q_dtype,
        kv_dtype,
        qkv_layout.value,
        attn_bias_type.value,
        attn_mask_type.value,
        dropout_probability,
        q_num_heads,
        kv_num_heads,
        q_max_seqlen,
        kv_max_seqlen,
        head_dim,
    ).is_fused_attn_kernel_available()
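

# Example availability probe before selecting the fused path (a minimal sketch;
# the dtypes, head counts, and sequence lengths are illustrative):
#
#     ok = is_fused_attn_kernel_available(
#         jnp.bfloat16, jnp.bfloat16, QKVLayout.BS3HD,
#         AttnBiasType.NO_BIAS, AttnMaskType.CAUSAL_MASK,
#         0.0,        # dropout_probability
#         16, 16,     # q_num_heads, kv_num_heads
#         512, 512,   # q_max_seqlen, kv_max_seqlen
#         64,         # head_dim
#     )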


def fused_attn_qkvpacked(
    qkv: jnp.ndarray,
    bias: jnp.ndarray | None,
    mask: jnp.ndarray,
    seed: jnp.ndarray | None,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Fused attention with the qkv-packed input.
    """
    output = _fused_attn_qkvpacked(
        qkv,
        bias,
        mask,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    return output
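

# Example self-attention call (a minimal sketch; the packing follows the BS3HD
# layout above and the names are illustrative):
#
#     qkv = jnp.stack([q, k, v], axis=2)  # (b, s, 3, h, d)
#     out = fused_attn_qkvpacked(
#         qkv, None, mask, None,
#         attn_bias_type=AttnBiasType.NO_BIAS,
#         attn_mask_type=AttnMaskType.PADDING_MASK,
#         scaling_factor=head_dim ** -0.5,
#         dropout_probability=0.0,
#         is_training=True,
#     )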


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _fused_attn_qkvpacked(
    qkv: jnp.ndarray,
    bias: jnp.ndarray | None,
    mask: jnp.ndarray,
    seed: jnp.ndarray | None,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):

    output, _ = _fused_attn_fwd_qkvpacked_rule(
        qkv,
        bias,
        mask,
        seed,
        attn_bias_type,
        attn_mask_type,
        scaling_factor,
        dropout_probability,
        is_training,
    )
    return output


def _fused_attn_fwd_qkvpacked_rule(
    qkv: jnp.ndarray,
    bias: jnp.ndarray | None,
    mask: jnp.ndarray,
    seed: jnp.ndarray | None,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, seqlen, *_ = qkv.shape
        actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
    else:
        assert mask is not None
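        # The incoming mask marks masked-out positions with 1; invert it so 1
        # marks valid tokens, then sum to recover the per-sequence lengths.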
        mask = jnp.logical_not(mask)
        actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
    output, softmax_aux, rng_state = tex.fused_attn_fwd_qkvpacked(
        qkv,
        bias,
        actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )
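    # Tag the saved activations so jax.ad_checkpoint rematerialization policies
    # (e.g. jax.checkpoint_policies.save_only_these_names("context")) can
    # choose to keep them.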
    output = checkpoint_name(output, "context")
    softmax_aux = checkpoint_name(softmax_aux, "context")
    rng_state = checkpoint_name(rng_state, "context")
    return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)


def _fused_attn_bwd_qkvpacked_rule(
    attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
):
    qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx

    grad_qkv, grad_bias = tex.fused_attn_bwd_qkvpacked(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        actual_seqlen,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_qkv, grad_bias, None, None


_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)


def fused_attn_kvpacked(
    q: jnp.ndarray,
    kv: jnp.ndarray,
    bias: jnp.ndarray,
    mask: jnp.ndarray,
    seed: jnp.ndarray,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Fused attention with the kv-packed inputs.
    """

    output = _fused_attn_kvpacked(
        q,
        kv,
        bias,
        mask,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    return output
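

# Example cross-attention call (a minimal sketch; names and shapes are
# illustrative):
#
#     kv = jnp.stack([k, v], axis=2)  # (b, s_kv, 2, h, d)
#     out = fused_attn_kvpacked(
#         q, kv, bias, mask, seed,
#         attn_bias_type=AttnBiasType.NO_BIAS,
#         attn_mask_type=AttnMaskType.PADDING_MASK,
#         scaling_factor=head_dim ** -0.5,
#         dropout_probability=0.0,
#         is_training=True,
#     )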


@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _fused_attn_kvpacked(
    q: jnp.ndarray,
    kv: jnp.ndarray,
    bias: jnp.ndarray,
    mask: jnp.ndarray,
    seed: jnp.ndarray,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):

    output, _ = _fused_attn_fwd_kvpacked_rule(
        q,
        kv,
        bias,
        mask,
        seed,
        attn_bias_type,
        attn_mask_type,
        scaling_factor,
        dropout_probability,
        is_training,
    )
    return output


def _fused_attn_fwd_kvpacked_rule(
    q,
    kv,
    bias,
    mask,
    seed,
    attn_bias_type,
    attn_mask_type,
    scaling_factor,
    dropout_probability,
    is_training,
):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, s_q, *_ = q.shape
        s_kv = kv.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
        else:
            # For causal masks the row sums differ per row, so the actual KV
            # seqlen cannot be read from a fixed row; take the max to find it.
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output, softmax_aux, rng_state = tex.fused_attn_fwd_kvpacked(
        q,
        kv,
        bias,
        q_actual_seqlen,
        kv_actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )
    output = checkpoint_name(output, "context")
    softmax_aux = checkpoint_name(softmax_aux, "context")
    rng_state = checkpoint_name(rng_state, "context")
    return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)


def _fused_attn_bwd_kvpacked_rule(
    attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
):
    q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

    grad_q, grad_kv, grad_bias = tex.fused_attn_bwd_kvpacked(
        q,
        kv,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        q_actual_seqlen,
        kv_actual_seqlen,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_q, grad_kv, grad_bias, None, None


_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)


def fused_attn(
    q: jnp.ndarray,
    k: jnp.ndarray,
    v: jnp.ndarray,
    bias: jnp.ndarray,
    mask: jnp.ndarray,
    seed: jnp.ndarray,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):
    """
    Dot product attention with the seperated query, key, value
    """

    output = _fused_attn(
        q,
        k,
        v,
        bias,
        mask,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    return output
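

# Example call with separate q/k/v tensors, each of shape (b, s, h, d)
# (a minimal sketch; names are illustrative):
#
#     out = fused_attn(
#         q, k, v, bias, mask, seed,
#         attn_bias_type=AttnBiasType.POST_SCALE_BIAS,
#         attn_mask_type=AttnMaskType.PADDING_CAUSAL_MASK,
#         scaling_factor=head_dim ** -0.5,
#         dropout_probability=0.1,
#         is_training=True,
#     )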


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
def _fused_attn(
    q: jnp.ndarray,
    k: jnp.ndarray,
    v: jnp.ndarray,
    bias: jnp.ndarray,
    mask: jnp.ndarray,
    seed: jnp.ndarray,
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
):

    output, _ = _fused_attn_fwd_rule(
        q,
        k,
        v,
        bias,
        mask,
        seed,
        attn_bias_type,
        attn_mask_type,
        scaling_factor,
        dropout_probability,
        is_training,
    )
    return output


def _fused_attn_fwd_rule(
    q,
    k,
    v,
    bias,
    mask,
    seed,
    attn_bias_type,
    attn_mask_type,
    scaling_factor,
    dropout_probability,
    is_training,
):
    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
        batch, s_q, *_ = q.shape
        s_kv = k.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        assert mask is not None
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
        if attn_mask_type == AttnMaskType.PADDING_MASK:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
        else:
            # For causal masks the row sums differ per row, so the actual KV
            # seqlen cannot be read from a fixed row; take the max to find it.
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        q,
        k,
        v,
        bias,
        q_actual_seqlen,
        kv_actual_seqlen,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )
    output = checkpoint_name(output, "context")
    softmax_aux = checkpoint_name(softmax_aux, "context")
    rng_state = checkpoint_name(rng_state, "context")
    return output, (
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        q_actual_seqlen,
        kv_actual_seqlen,
    )


def _fused_attn_bwd_rule(
    attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
):
    q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

    grad_q, grad_k, grad_v, grad_bias = tex.fused_attn_bwd(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        dz,
        q_actual_seqlen,
        kv_actual_seqlen,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_q, grad_k, grad_v, grad_bias, None, None


_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)