# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout

from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .cpp_extensions import fused_attn_fwd, fused_attn_bwd


class AttnBiasType(Enum):
    """Attention Bias Type."""
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """Attention Mask Type."""
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK


class QKVLayout(Enum):
    """QKV layout"""
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
    BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
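    # Layout summary (inferred from the enum names and the wrappers below): BS3HD
    # packs Q, K and V into one tensor (cf. self_fused_attn); BSHD_BS2HD keeps Q
    # separate with K and V packed together (cf. cross_fused_attn); BSHD_BSHD_BSHD
    # keeps Q, K and V fully separate (cf. fused_attn).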


def canonicalize_attn_mask_type(attn_mask_type: str):
    """Convert string attn_mask_type to AttnMaskType
    TE-JAX currently fall back to the padding version kernels for the libraries integration.
    The overhead between padding and non-padding version should be small.
    However, we will lease this limitation in the near feature.
    """
    if attn_mask_type in ['causal', 'padding_causal']:
        return AttnMaskType.PADDING_CAUSAL_MASK
    if attn_mask_type in ['no_mask', 'padding']:
        return AttnMaskType.PADDING_MASK
    raise ValueError(f"Unsupported {attn_mask_type=}, "
                     "supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}")


def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
                                   dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
                                   max_seqlen_kv, head_dim):
    """
    To check whether the fused attention kernel is available
    """
    return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
                           attn_mask_type.value, dropout_probability, num_heads_q, num_heads_kv,
                           max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available()
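
# Usage sketch (argument values are illustrative; availability depends on the
# installed cuDNN/Transformer Engine build and the GPU):
#     is_fused_attn_kernel_available(jnp.float16, jnp.float16, QKVLayout.BS3HD,
#                                    AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
#                                    0.0, 16, 16, 512, 512, 64)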


def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                    seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                    attn_mask_type: AttnMaskType, scaling_factor: float,
                    dropout_probability: float, is_training: bool):
    """
    Self fused attention wrapper
    """
    output = _self_fused_attn(qkv,
                              bias,
                              mask,
                              seed,
                              attn_bias_type=attn_bias_type,
                              attn_mask_type=attn_mask_type,
                              scaling_factor=scaling_factor,
                              dropout_probability=dropout_probability,
                              is_training=is_training)

    return output
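
# Minimal usage sketch (shapes assume the packed BS3HD layout; the names and values
# are illustrative, not prescriptive):
#     qkv = jnp.zeros((batch, seqlen, 3, num_heads, head_dim), jnp.float16)
#     out = self_fused_attn(qkv, None, padding_mask, rng_seed,
#                           attn_bias_type=AttnBiasType.NO_BIAS,
#                           attn_mask_type=AttnMaskType.PADDING_CAUSAL_MASK,
#                           scaling_factor=head_dim ** -0.5,
#                           dropout_probability=0.0, is_training=True)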


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray,
                     seed: jnp.ndarray | None, attn_bias_type: AttnBiasType,
                     attn_mask_type: AttnMaskType, scaling_factor: float,
                     dropout_probability: float, is_training: bool):
    output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                          scaling_factor, dropout_probability, is_training)
    return output


def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None,
                              mask: jnp.ndarray, seed: jnp.ndarray | None,
                              attn_bias_type: AttnBiasType,
                              attn_mask_type: AttnMaskType,
                              scaling_factor: float, dropout_probability: float,
                              is_training: bool):
    if mask is None:
        batch, seqlen, *_ = qkv.shape
        actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
    else:
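        # `mask` marks positions to exclude (True = masked out), so invert it before
        # counting the valid tokens per sequence.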
        mask = jnp.logical_not(mask)
        actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
    output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
                                                         bias,
                                                         actual_seqlen,
                                                         seed,
                                                         attn_bias_type=attn_bias_type.value,
                                                         attn_mask_type=attn_mask_type.value,
                                                         scaling_factor=scaling_factor,
                                                         dropout_probability=dropout_probability,
                                                         is_training=is_training)
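    # Name these activations so jax.checkpoint rematerialization policies (e.g.
    # save-only-these-names) can keep them for the backward pass instead of
    # recomputing them.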
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)


def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                              is_training, ctx, dz):
    qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx

    grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
                                              bias,
                                              softmax_aux,
                                              rng_state,
                                              output,
                                              dz,
                                              actual_seqlen,
                                              attn_bias_type=attn_bias_type.value,
                                              attn_mask_type=attn_mask_type.value,
                                              scaling_factor=scaling_factor,
                                              dropout_probability=dropout_probability,
                                              is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_qkv, grad_bias, None, None


_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)


def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                     seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Cross multi-head attention wrapper
    """

    output = _cross_fused_attn(q,
                               kv,
                               bias,
                               mask,
                               seed,
                               attn_bias_type=attn_bias_type,
                               attn_mask_type=attn_mask_type,
                               scaling_factor=scaling_factor,
                               dropout_probability=dropout_probability,
                               is_training=is_training)

    return output
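
# Shape note (illustrative, following the BSHD_BS2HD layout): `q` is
# (batch, s_q, num_heads, head_dim) and `kv` packs key/value together as
# (batch, s_kv, 2, num_heads, head_dim).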


@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                      seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):
    output, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                           scaling_factor, dropout_probability, is_training)
    return output


def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
                               scaling_factor, dropout_probability, is_training):
    if mask is None:
        batch, s_q, *_ = q.shape
        s_kv = kv.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        else:
            # With a causal mask the per-row key counts vary by row, so the fixed
            # [0, 0] index no longer gives the true KV length; take the max over rows.
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
                                                          kv,
                                                          bias,
                                                          q_actual_seqlen,
                                                          kv_actual_seqlen,
                                                          seed,
                                                          attn_bias_type=attn_bias_type.value,
                                                          attn_mask_type=attn_mask_type.value,
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)


def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                               is_training, ctx, dz):
    q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

    grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
                                                      kv,
                                                      bias,
                                                      softmax_aux,
                                                      rng_state,
                                                      output,
                                                      dz,
                                                      q_actual_seqlen,
                                                      kv_actual_seqlen,
                                                      attn_bias_type=attn_bias_type.value,
                                                      attn_mask_type=attn_mask_type.value,
                                                      scaling_factor=scaling_factor,
                                                      dropout_probability=dropout_probability,
                                                      is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_q, grad_kv, grad_bias, None, None


_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)


def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
               seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
               scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Dot product attention with the seperated query, key, value
    """

    output = _fused_attn(q,
                         k,
                         v,
                         bias,
                         mask,
                         seed,
                         attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         scaling_factor=scaling_factor,
                         dropout_probability=dropout_probability,
                         is_training=is_training)

    return output
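
# Shape note (illustrative, following the BSHD_BSHD_BSHD layout): `q`, `k` and `v`
# are separate (batch, seqlen, num_heads, head_dim) tensors, with `k` and `v`
# sharing the same sequence length.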


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
                mask: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                attn_mask_type: AttnMaskType, scaling_factor: float, dropout_probability: float,
                is_training: bool):

    output, _ = _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type,
                                     scaling_factor, dropout_probability, is_training)
    return output


def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                         dropout_probability, is_training):
    if mask is None:
        batch, s_q, *_ = q.shape
        s_kv = k.shape[1]
        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
    else:
        mask = jnp.logical_not(mask)
        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
        else:
            # With a causal mask the per-row key counts vary by row, so the fixed
            # [0, 0] index no longer gives the true KV length; take the max over rows.
            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

    output, softmax_aux, rng_state = fused_attn_fwd(q,
                                                    k,
                                                    v,
                                                    bias,
                                                    q_actual_seqlen,
                                                    kv_actual_seqlen,
                                                    seed,
                                                    attn_bias_type=attn_bias_type.value,
                                                    attn_mask_type=attn_mask_type.value,
                                                    scaling_factor=scaling_factor,
                                                    dropout_probability=dropout_probability,
                                                    is_training=is_training)
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')
    return output, (q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen,
                    kv_actual_seqlen)


def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                         is_training, ctx, dz):
    q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

    grad_q, grad_k, grad_v, grad_bias = fused_attn_bwd(q,
                                                       k,
                                                       v,
                                                       bias,
                                                       softmax_aux,
                                                       rng_state,
                                                       output,
                                                       dz,
                                                       q_actual_seqlen,
                                                       kv_actual_seqlen,
                                                       attn_bias_type=attn_bias_type.value,
                                                       attn_mask_type=attn_mask_type.value,
                                                       scaling_factor=scaling_factor,
                                                       dropout_probability=dropout_probability,
                                                       is_training=is_training)

    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_q, grad_k, grad_v, grad_bias, None, None


_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)