# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout

from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd


class AttnBiasType(Enum):
    """Attention Bias Type.

    Thin Python mirror of the backend ``NVTE_Bias_Type`` enum; the ``.value``
    is forwarded directly to the C++ fused-attention extensions.
    """
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """Attention Mask Type.

    Thin Python mirror of the backend ``NVTE_Mask_Type`` enum; the ``.value``
    is forwarded directly to the C++ fused-attention extensions.
    """
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
class QKVLayout(Enum):
    """QKV layout.

    Thin Python mirror of the backend ``NVTE_QKV_Layout`` enum.
    Names follow the NVTE layout convention (presumably B=batch, S=sequence,
    H=heads, D=head dim — confirm against the backend docs):
    """
    # Query/key/value packed into one tensor.
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    # Separate query tensor; key/value packed together.
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    # Fully separate query, key, and value tensors.
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert a string ``attn_mask_type`` to the corresponding :class:`AttnMaskType`.

    TE-JAX currently falls back to the padding-version kernels for the libraries
    integration; the overhead between the padding and non-padding versions should
    be small. We will release this limitation in the near future.

    Raises:
        ValueError: if the string is not one of the supported mask types.
    """
    if attn_mask_type in ('no_mask', 'padding'):
        return AttnMaskType.PADDING_MASK
    if attn_mask_type in ('causal', 'padding_causal'):
        return AttnMaskType.PADDING_CAUSAL_MASK
    raise ValueError(f"Unsupported {attn_mask_type=}, "
                     "supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}")


def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
                                   dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
                                   max_seqlen_kv, head_dim):
    """
    To check whether the fused attention kernel is available for the given
    dtypes, layout, bias/mask configuration, and problem sizes.
    """
    # Delegate the capability query to the C++ helper; enum members are
    # unwrapped to their backend integer values.
    helper = FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
                             attn_mask_type.value, dropout_probability, num_heads_q,
                             num_heads_kv, max_seqlen_q, max_seqlen_kv, head_dim)
    return helper.is_fused_attn_kernel_available()
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                    attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                    scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Self fused attention wrapper.

    Thin public entry point over the custom-VJP implementation so the
    differentiable core stays private.
    """
    return _self_fused_attn(qkv,
                            bias,
                            mask,
                            seed,
                            attn_bias_type=attn_bias_type,
                            attn_mask_type=attn_mask_type,
                            scaling_factor=scaling_factor,
                            dropout_probability=dropout_probability,
                            is_training=is_training)


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    """Differentiable core of self fused attention.

    The primal computation simply reuses the forward VJP rule and drops the
    residuals it returns.
    """
    primal_out, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type,
                                              attn_mask_type, scaling_factor,
                                              dropout_probability, is_training)
    return primal_out
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                              seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                              attn_mask_type: AttnMaskType, scaling_factor: float,
                              dropout_probability: float, is_training: bool):
    """Forward rule: derive per-batch sequence lengths, run the fused forward
    kernel, and stash the residuals the backward rule needs."""
    if mask is not None:
        # The mask marks positions to drop; invert it so 1 == valid token,
        # then count valid rows per batch element.
        keep = jnp.logical_not(mask)
        seqlens = jnp.sum(keep, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
    else:
        # No mask given: every sequence is fully valid.
        batch, seqlen, *_ = qkv.shape
        seqlens = jnp.full((batch,), seqlen, dtype=jnp.int32)

    out, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
                                                      bias,
                                                      seqlens,
                                                      seed,
                                                      attn_bias_type=attn_bias_type.value,
                                                      attn_mask_type=attn_mask_type.value,
                                                      scaling_factor=scaling_factor,
                                                      dropout_probability=dropout_probability,
                                                      is_training=is_training)
    # Tag intermediates so activation-checkpointing policies can target them.
    out = checkpoint_name(out, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return out, (qkv, bias, softmax_aux, rng_state, out, seqlens)
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                              is_training, ctx, dz):
    """Backward rule: unpack the forward residuals and call the fused backward kernel."""
    qkv, bias, softmax_aux, rng_state, fwd_out, seqlens = ctx

    grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
                                              bias,
                                              softmax_aux,
                                              rng_state,
                                              fwd_out,
                                              dz,
                                              seqlens,
                                              attn_bias_type=attn_bias_type.value,
                                              attn_mask_type=attn_mask_type.value,
                                              scaling_factor=scaling_factor,
                                              dropout_probability=dropout_probability,
                                              is_training=is_training)

    # A NO_BIAS run has no bias cotangent to propagate.
    grad_bias = None if attn_bias_type == AttnBiasType.NO_BIAS else grad_bias

    # mask and seed take no gradient.
    return grad_qkv, grad_bias, None, None
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
147
148


149
150
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                     seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Cross multi-head attention wrapper.

    Thin public entry point over the custom-VJP implementation so the
    differentiable core stays private.
    """
    return _cross_fused_attn(q,
                             kv,
                             bias,
                             mask,
                             seed,
                             attn_bias_type=attn_bias_type,
                             attn_mask_type=attn_mask_type,
                             scaling_factor=scaling_factor,
                             dropout_probability=dropout_probability,
                             is_training=is_training)
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                      seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):
    """Differentiable core of cross fused attention.

    The primal computation simply reuses the forward VJP rule and drops the
    residuals it returns.
    """
    primal_out, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type,
                                               attn_mask_type, scaling_factor,
                                               dropout_probability, is_training)
    return primal_out
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
                               scaling_factor, dropout_probability, is_training):
    """Forward rule: derive per-batch q/kv sequence lengths, run the fused
    forward kernel, and stash the residuals the backward rule needs."""
    if mask is not None:
        # The mask marks positions to drop; invert it so 1 == valid token.
        keep = jnp.logical_not(mask)
        q_seqlen = jnp.sum(keep, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type in (AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK):
            # When mask is causal, the actual seqlen is not the last row, use max to find it
            kv_seqlen = jnp.max(jnp.sum(keep, axis=-1, dtype=jnp.int32), axis=(-1, -2))
        else:
            kv_seqlen = jnp.sum(keep, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
    else:
        # No mask given: both sequences are fully valid.
        batch, s_q, *_ = q.shape
        s_kv = kv.shape[1]
        q_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)

    out, softmax_aux, rng_state = cross_fused_attn_fwd(q,
                                                       kv,
                                                       bias,
                                                       q_seqlen,
                                                       kv_seqlen,
                                                       seed,
                                                       attn_bias_type=attn_bias_type.value,
                                                       attn_mask_type=attn_mask_type.value,
                                                       scaling_factor=scaling_factor,
                                                       dropout_probability=dropout_probability,
                                                       is_training=is_training)
    # Tag intermediates so activation-checkpointing policies can target them.
    out = checkpoint_name(out, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return out, (q, kv, bias, softmax_aux, rng_state, out, q_seqlen, kv_seqlen)
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                               is_training, ctx, dz):
    """Backward rule: unpack the forward residuals and call the fused backward kernel."""
    q, kv, bias, softmax_aux, rng_state, fwd_out, q_seqlen, kv_seqlen = ctx

    grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
                                                      kv,
                                                      bias,
                                                      softmax_aux,
                                                      rng_state,
                                                      fwd_out,
                                                      dz,
                                                      q_seqlen,
                                                      kv_seqlen,
                                                      attn_bias_type=attn_bias_type.value,
                                                      attn_mask_type=attn_mask_type.value,
                                                      scaling_factor=scaling_factor,
                                                      dropout_probability=dropout_probability,
                                                      is_training=is_training)

    # A NO_BIAS run has no bias cotangent to propagate.
    grad_bias = None if attn_bias_type == AttnBiasType.NO_BIAS else grad_bias

    # mask and seed take no gradient.
    return grad_q, grad_kv, grad_bias, None, None
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335


def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
               seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
               scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Dot product attention with the separated query, key, value.

    Thin public entry point over the custom-VJP implementation so the
    differentiable core stays private.
    """
    return _fused_attn(q,
                       k,
                       v,
                       bias,
                       mask,
                       seed,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
                mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float,
                is_training: bool):
    """Differentiable core of separated-QKV fused attention.

    The primal computation simply reuses the forward VJP rule and drops the
    residuals it returns.
    """
    primal_out, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type,
                                         attn_mask_type, scaling_factor, dropout_probability,
                                         is_training)
    return primal_out


def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                         dropout_probability, is_training):
    """Forward rule: derive per-batch q/kv sequence lengths, run the fused
    forward kernel, and stash the residuals the backward rule needs."""
    if mask is not None:
        # The mask marks positions to drop; invert it so 1 == valid token.
        keep = jnp.logical_not(mask)
        q_seqlen = jnp.sum(keep, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type in (AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK):
            # When mask is causal, the actual seqlen is not the last row, use max to find it
            kv_seqlen = jnp.max(jnp.sum(keep, axis=-1, dtype=jnp.int32), axis=(-1, -2))
        else:
            kv_seqlen = jnp.sum(keep, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
    else:
        # No mask given: both sequences are fully valid.
        batch, s_q, *_ = q.shape
        s_kv = k.shape[1]
        q_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)

    out, softmax_aux, rng_state = fused_attn_fwd(q,
                                                 k,
                                                 v,
                                                 bias,
                                                 q_seqlen,
                                                 kv_seqlen,
                                                 seed,
                                                 attn_bias_type=attn_bias_type.value,
                                                 attn_mask_type=attn_mask_type.value,
                                                 scaling_factor=scaling_factor,
                                                 dropout_probability=dropout_probability,
                                                 is_training=is_training)
    # Tag intermediates so activation-checkpointing policies can target them.
    out = checkpoint_name(out, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return out, (q, k, v, bias, softmax_aux, rng_state, out, q_seqlen, kv_seqlen)


def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                         is_training, ctx, dz):
    """Backward rule: unpack the forward residuals and call the fused backward kernel."""
    q, k, v, bias, softmax_aux, rng_state, fwd_out, q_seqlen, kv_seqlen = ctx

    grad_q, grad_k, grad_v, grad_bias = fused_attn_bwd(q,
                                                       k,
                                                       v,
                                                       bias,
                                                       softmax_aux,
                                                       rng_state,
                                                       fwd_out,
                                                       dz,
                                                       q_seqlen,
                                                       kv_seqlen,
                                                       attn_bias_type=attn_bias_type.value,
                                                       attn_mask_type=attn_mask_type.value,
                                                       scaling_factor=scaling_factor,
                                                       dropout_probability=dropout_probability,
                                                       is_training=is_training)

    # A NO_BIAS run has no bias cotangent to propagate.
    grad_bias = None if attn_bias_type == AttnBiasType.NO_BIAS else grad_bias

    # mask and seed take no gradient.
    return grad_q, grad_k, grad_v, grad_bias, None, None


# Register the forward/backward rules for the separated-QKV custom VJP.
_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)