fused_attn.py 10.1 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
8
from jax.ad_checkpoint import checkpoint_name
9
10
11
12
13
import jax
import jax.numpy as jnp

from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
14
from transformer_engine_jax import NVTE_QKV_Layout
15

16
from .cpp_extensions import FusedAttnHelper
17
18
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
19
20
21
22
23
24
25
26
27
28
29
30
31
32


class AttnBiasType(Enum):
    """Attention Bias Type.

    Each member wraps the corresponding ``NVTE_Bias_Type`` constant imported from
    ``transformer_engine_jax``. The fused attention wrappers in this module pass the
    member's ``.value`` on to the extension calls.
    """
    # No bias is applied; the backward rules in this module return a ``None`` bias gradient.
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    # NOTE(review): presumably the bias is added before the scaling factor is applied - confirm.
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    # NOTE(review): presumably the bias is added after the scaling factor is applied - confirm.
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """Attention Mask Type.

    Python-side mirror of ``NVTE_Mask_Type``; each member's ``.value`` is the matching
    ``transformer_engine_jax`` constant that the fused attention wrappers forward to the
    extension calls.
    """
    # No masking requested.
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    # Padding mask only.
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    # Causal mask only.
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    # Padding and causal masks combined.
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
34
35


36
37
38
39
40
41
42
class QKVLayout(Enum):
    """QKV layout

    Each member wraps the corresponding ``NVTE_QKV_Layout`` constant imported from
    ``transformer_engine_jax``; ``is_fused_attn_kernel_available`` forwards the member's
    ``.value`` to ``FusedAttnHelper``.
    """
    # NOTE(review): name suggests one packed tensor holding Q, K and V - confirm.
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    # NOTE(review): name suggests a separate Q tensor plus a packed K/V tensor - confirm.
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD


def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
                                   dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q,
                                   max_seqlen_kv, head_dim):
    """
    Report whether a fused attention kernel exists for the given configuration.

    All arguments are handed to ``FusedAttnHelper`` in the order received. The three enum
    arguments (``qkv_layout``, ``attn_bias_type``, ``attn_mask_type``) are unwrapped to
    their ``.value`` first; everything else is passed through unchanged.

    Returns whatever ``FusedAttnHelper.is_fused_attn_kernel_available()`` returns
    (presumably a bool - confirm against the helper).
    """
    helper = FusedAttnHelper(
        q_type,
        kv_type,
        qkv_layout.value,
        attn_bias_type.value,
        attn_mask_type.value,
        dropout_probability,
        num_heads_q,
        num_heads_kv,
        max_seqlen_q,
        max_seqlen_kv,
        head_dim,
    )
    return helper.is_fused_attn_kernel_available()
51
52


53
54
55
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                    attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                    scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Self fused attention wrapper

    Public entry point that forwards every argument unchanged to the custom-VJP function
    ``_self_fused_attn`` and returns its result.

    Args:
        qkv: query/key/value input (presumably one packed tensor - confirm layout).
        bias: attention bias input.
        mask: attention mask; the forward rule inverts it with ``jnp.logical_not`` before
            counting sequence lengths.
        seed: forwarded to the fused kernel (presumably the dropout RNG seed - confirm).
        attn_bias_type: bias mode; its ``.value`` is forwarded to the kernel.
        attn_mask_type: mask mode; its ``.value`` is forwarded to the kernel.
        scaling_factor: forwarded to the kernel.
        dropout_probability: forwarded to the kernel.
        is_training: forwarded to the kernel.

    Returns:
        The attention output produced by the fused forward kernel.
    """
    # The five configuration arguments are non-differentiable in ``_self_fused_attn``;
    # passing them by keyword keeps the call readable.
    return _self_fused_attn(
        qkv,
        bias,
        mask,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Custom-VJP primal for self fused attention.

    Positional arguments 4-8 (bias type, mask type, scaling factor, dropout probability,
    training flag) are declared non-differentiable. The primal value is the first element
    of the forward rule's result; the residuals are discarded here.
    """
    fwd_result = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type,
                                           attn_mask_type, scaling_factor,
                                           dropout_probability, is_training)
    return fwd_result[0]
80
81


82
83
84
85
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                              seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                              attn_mask_type: AttnMaskType, scaling_factor: float,
                              dropout_probability: float, is_training: bool):
    """
    Forward rule of the self fused attention custom VJP.

    Derives per-batch sequence lengths from ``mask``, runs ``self_fused_attn_fwd`` and tags
    its three results with the checkpoint name ``'context'``.

    Returns:
        ``(output, residuals)`` where ``residuals`` is
        ``(qkv, bias, softmax_aux, rng_state, output, actual_seqlen)`` for the backward rule.
    """
    # The incoming mask is inverted before counting, so nonzero entries of the inverted
    # mask are what get summed (presumably the valid, unmasked positions - confirm).
    keep = jnp.logical_not(mask)
    seqlens = jnp.sum(keep, axis=-2, dtype=jnp.int32)
    seqlens = seqlens[..., 0, 0]    # shape = (b,)

    output, softmax_aux, rng_state = self_fused_attn_fwd(
        qkv,
        bias,
        seqlens,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    # Name the kernel results so checkpoint policies can refer to them as 'context'.
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')

    residuals = (qkv, bias, softmax_aux, rng_state, output, seqlens)
    return output, residuals
101
102


103
104
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                              is_training, ctx, dz):
    """
    Backward rule of the self fused attention custom VJP.

    Receives the five non-differentiable arguments first, then the residuals ``ctx`` saved
    by the forward rule and the output cotangent ``dz``.

    Returns:
        One cotangent per differentiable primal argument ``(qkv, bias, mask, seed)``:
        ``(grad_qkv, grad_bias, None, None)``. ``grad_bias`` is ``None`` when
        ``attn_bias_type`` is ``AttnBiasType.NO_BIAS``.
    """
    (qkv, bias, softmax_aux, rng_state, fwd_output, seqlens) = ctx

    grad_qkv, grad_bias = self_fused_attn_bwd(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        fwd_output,
        dz,
        seqlens,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    # Without a bias there is no bias gradient to report; mask and seed never get one.
    bias_cotangent = None if attn_bias_type == AttnBiasType.NO_BIAS else grad_bias
    return grad_qkv, bias_cotangent, None, None


126
# Attach the forward and backward rules to the custom-VJP function _self_fused_attn.
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
127
128


129
130
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                     seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Cross multi-head attention wrapper

    Public entry point that forwards every argument unchanged to the custom-VJP function
    ``_cross_fused_attn`` and returns its result.

    Args:
        q: query input.
        kv: key/value input (presumably one packed tensor - confirm layout).
        bias: attention bias input.
        mask: attention mask; the forward rule inverts it with ``jnp.logical_not`` before
            counting query and key/value sequence lengths.
        seed: forwarded to the fused kernel (presumably the dropout RNG seed - confirm).
        attn_bias_type: bias mode; its ``.value`` is forwarded to the kernel.
        attn_mask_type: mask mode; also selects how the key/value length is derived.
        scaling_factor: forwarded to the kernel.
        dropout_probability: forwarded to the kernel.
        is_training: forwarded to the kernel.

    Returns:
        The attention output produced by the fused forward kernel.
    """
    # The five configuration arguments are non-differentiable in ``_cross_fused_attn``;
    # passing them by keyword keeps the call readable.
    return _cross_fused_attn(
        q,
        kv,
        bias,
        mask,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )


150
151
152
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                      seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Custom-VJP primal for cross fused attention.

    Positional arguments 5-9 (bias type, mask type, scaling factor, dropout probability,
    training flag) are declared non-differentiable. The primal value is the first element
    of the forward rule's result; the residuals are discarded here.
    """
    fwd_result = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type,
                                            attn_mask_type, scaling_factor,
                                            dropout_probability, is_training)
    return fwd_result[0]


160
161
def _cross_fused_attn_fwd_rule(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                               mask: jnp.ndarray, seed: jnp.ndarray,
                               attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                               scaling_factor: float, dropout_probability: float,
                               is_training: bool):
    """
    Forward rule of the cross fused attention custom VJP.

    Derives per-batch query and key/value sequence lengths from ``mask``, runs
    ``cross_fused_attn_fwd`` and tags its three results with the checkpoint name
    ``'context'``, matching the self-attention forward rule so that name-based
    rematerialization policies treat both attention paths the same way.

    Returns:
        ``(output, residuals)`` where ``residuals`` is
        ``(q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)``
        for the backward rule.
    """
    # The incoming mask is inverted before counting, so nonzero entries of the inverted
    # mask are what get summed (presumably the valid, unmasked positions - confirm).
    mask = jnp.logical_not(mask)
    q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)

    if attn_mask_type in (AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK):
        # When mask is padding + causal, the actual seqlen is not the last row, use max to find it
        kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
    else:
        kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)

    output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
                                                          kv,
                                                          bias,
                                                          q_actual_seqlen,
                                                          kv_actual_seqlen,
                                                          seed,
                                                          attn_bias_type=attn_bias_type.value,
                                                          attn_mask_type=attn_mask_type.value,
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training)

    # checkpoint_name leaves the values unchanged; it only attaches the 'context' label
    # that the self-attention forward rule already applies to the same three results.
    output = checkpoint_name(output, 'context')
    softmax_aux = checkpoint_name(softmax_aux, 'context')
    rng_state = checkpoint_name(rng_state, 'context')

    return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
184
185


186
187
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                               is_training, ctx, dz):
    """
    Backward rule of the cross fused attention custom VJP.

    Receives the five non-differentiable arguments first, then the residuals ``ctx`` saved
    by the forward rule and the output cotangent ``dz``.

    Returns:
        One cotangent per differentiable primal argument ``(q, kv, bias, mask, seed)``:
        ``(grad_q, grad_kv, grad_bias, None, None)``. ``grad_bias`` is ``None`` when
        ``attn_bias_type`` is ``AttnBiasType.NO_BIAS``.
    """
    (q, kv, bias, softmax_aux, rng_state, fwd_output, q_seqlens, kv_seqlens) = ctx

    grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(
        q,
        kv,
        bias,
        softmax_aux,
        rng_state,
        fwd_output,
        dz,
        q_seqlens,
        kv_seqlens,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
    )

    # Without a bias there is no bias gradient to report; mask and seed never get one.
    bias_cotangent = None if attn_bias_type == AttnBiasType.NO_BIAS else grad_bias
    return grad_q, grad_kv, bias_cotangent, None, None
209
210


211
# Attach the forward and backward rules to the custom-VJP function _cross_fused_attn.
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)