fused_attn.py 8.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
import jax
import jax.numpy as jnp

from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
13
from transformer_engine_jax import NVTE_QKV_Layout
14

15
from .cpp_extensions import FusedAttnHelper
16
17
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
18
19
20
21
22
23
24
25
26
27
28
29
30
31


class AttnBiasType(Enum):
    """Attention Bias Type.

    Python-side mirror of the backend's NVTE_Bias_Type; ``.value`` is passed
    straight through to the C++ extension calls.
    """
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """Attention Mask Type.

    Python-side mirror of the backend's NVTE_Mask_Type; ``.value`` is passed
    straight through to the C++ extension calls.
    """
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
    PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
33
34


35
36
37
38
39
40
41
class QKVLayout(Enum):
    """QKV layout.

    Python-side mirror of the backend's NVTE_QKV_Layout; ``.value`` is passed
    straight through to the C++ extension calls.
    """
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD


def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
                                   dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim):
    """
    To check whether the fused attention kernel is available.

    Wraps the problem description in a FusedAttnHelper and queries it; the
    enum arguments are unwrapped to their backend ``.value`` integers.
    """
    helper = FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
                             attn_mask_type.value, dropout_probability, max_seqlen_q,
                             max_seqlen_kv, head_dim)
    return helper.is_fused_attn_kernel_available()


51
52
53
def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                    attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                    scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Self fused attention wrapper.

    Thin public entry point that forwards every argument to the
    custom-VJP-wrapped implementation ``_self_fused_attn``.
    """
    return _self_fused_attn(qkv,
                            bias,
                            mask,
                            seed,
                            attn_bias_type=attn_bias_type,
                            attn_mask_type=attn_mask_type,
                            scaling_factor=scaling_factor,
                            dropout_probability=dropout_probability,
                            is_training=is_training)


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    """Differentiable core of self fused attention.

    The primal computation reuses the forward VJP rule and discards its
    residuals; gradients flow through the rules registered via ``defvjp``.
    """
    out, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                       scaling_factor, dropout_probability, is_training)
    return out
78
79


80
81
82
83
84
def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                              seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                              attn_mask_type: AttnMaskType, scaling_factor: float,
                              dropout_probability: float, is_training: bool):
    """Forward VJP rule: run the fused self-attention kernel and stash residuals."""
    # Slice the trailing axis off the rank-4 mask before handing it to the kernel.
    seq_mask = mask[:, :, :, 0]
    out, softmax_aux, rng_state = self_fused_attn_fwd(
        qkv,
        bias,
        seq_mask,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    # Residuals consumed by _self_fused_attn_bwd_rule, in this exact order.
    residuals = (qkv, softmax_aux, rng_state, out, seq_mask)
    return out, residuals
95
96


97
98
99
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                              is_training, ctx, dz):
    """Backward VJP rule: compute gradients w.r.t. qkv and bias.

    ``ctx`` is the residual tuple produced by ``_self_fused_attn_fwd_rule``;
    ``dz`` is the cotangent of the forward output.
    """
    qkv, softmax_aux, rng_state, fwd_out, seq_mask = ctx
    grad_qkv, grad_bias = self_fused_attn_bwd(
        qkv,
        softmax_aux,
        rng_state,
        fwd_out,
        dz,
        seq_mask,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    # Without a bias input there is no meaningful bias gradient to return.
    grad_bias = None if attn_bias_type == AttnBiasType.NO_BIAS else grad_bias
    # mask and seed are non-differentiable inputs.
    return grad_qkv, grad_bias, None, None


119
# Register the forward/backward rules for the custom VJP of _self_fused_attn.
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
120
121


122
123
124
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Cross multi-head attention wrapper.

    Thin public entry point that forwards every argument to the
    custom-VJP-wrapped implementation ``_cross_fused_attn``.
    """
    return _cross_fused_attn(q,
                             kv,
                             mask,
                             seed,
                             attn_bias_type=attn_bias_type,
                             attn_mask_type=attn_mask_type,
                             scaling_factor=scaling_factor,
                             dropout_probability=dropout_probability,
                             is_training=is_training)


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                      attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):
    """Differentiable core of cross fused attention.

    The primal computation reuses the forward VJP rule and discards its
    residuals; gradients flow through the rules registered via ``defvjp``.
    """
    out, _ = _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type,
                                        scaling_factor, dropout_probability, is_training)
    return out


152
153
def _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                               dropout_probability, is_training):
    """Forward VJP rule: run the fused cross-attention kernel and stash residuals."""
    # Reduce the rank-4 mask to a per-side rank-3 mask: fix the kv axis for the
    # query mask and the q axis for the key/value mask.
    q_mask = mask[:, :, :, 0]
    kv_mask = mask[:, :, 0, :]
    out, softmax_aux = cross_fused_attn_fwd(
        q,
        kv,
        q_mask,
        kv_mask,
        seed,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    # Residuals consumed by _cross_fused_attn_bwd_rule, in this exact order.
    return out, (softmax_aux, q, kv, q_mask, kv_mask)
169
170


171
172
173
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                               is_training, ctx, dz):
    """Backward VJP rule: compute gradients w.r.t. q and kv.

    ``ctx`` is the residual tuple produced by ``_cross_fused_attn_fwd_rule``;
    ``dz`` is the cotangent of the forward output.
    """
    softmax_aux, q, kv, q_mask, kv_mask = ctx
    grad_q, grad_kv = cross_fused_attn_bwd(
        q,
        kv,
        softmax_aux,
        dz,
        q_mask,
        kv_mask,
        attn_bias_type=attn_bias_type.value,
        attn_mask_type=attn_mask_type.value,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    # mask and seed are non-differentiable inputs.
    return grad_q, grad_kv, None, None


190
# Register the forward/backward rules for the custom VJP of _cross_fused_attn.
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)