# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
import jax
import jax.numpy as jnp

import transformer_engine_jax
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type

from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType
from .sharding import xmap_runner, extend_fsdp_sharding_meta

# The xmap-based sharded code paths below rely on SPMD lowering.
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)


def is_fused_attn_kernel_available():
    """
    To check whether the fused attention kernel is available
    """
    return transformer_engine_jax.is_fused_attn_kernel_available()
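

# Dispatch sketch (illustrative only; `unfused_attention_fn` is a hypothetical
# stand-in for a plain-JAX fallback, not part of this module): probe kernel
# availability once at setup time and branch there, rather than per call.
def _example_attn_dispatch(unfused_attention_fn):
    """Pick the attention callable to use. Illustrative, not public API."""
    if is_fused_attn_kernel_available():
        return self_fused_attn
    return unfused_attention_fn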


class AttnBiasType(Enum):
    """Attention Bias Type."""
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """Attention Mask Type."""
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK


def self_fused_attn(qkv: jnp.ndarray,
                    bias: jnp.ndarray,
                    mask: jnp.ndarray,
                    seed: jnp.ndarray,
                    attn_bias_type: AttnBiasType,
                    attn_mask_type: AttnMaskType,
                    scaling_factor: float,
                    dropout_probability: float,
                    is_training: bool,
                    sharding_type: ShardingType = ShardingType.SINGLE):
    """
    Self fused attention wrapper
    """
    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
60
        "self_fused_attn does not support row-split tensor parallelism currently."
61
62

    if sharding_type is ShardingType.SINGLE:
        output = _self_fused_attn(qkv,
                                  bias,
                                  mask,
                                  seed,
                                  attn_bias_type=attn_bias_type,
                                  attn_mask_type=attn_mask_type,
                                  scaling_factor=scaling_factor,
                                  dropout_probability=dropout_probability,
                                  is_training=is_training)
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"

        inputs = [qkv, bias, mask, seed]
        batch, seqlen, _, num_head, head_dim = qkv.shape
        output_shape = [batch, seqlen, num_head, head_dim]
        # dp_dims/tp_dims list, per input ([qkv, bias, mask, seed]) and per
        # output, which axis to shard for data and tensor parallelism
        # (None means the tensor is not split along that parallelism kind).
        sharding_meta = get_fused_attn_sharding_meta(
            sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
            dp_dims=([0, None, 0, 0], [0]),
            tp_dims=([3, 1, None, 0], [2]),
            dp_axis_name=dp_axis_name,
            tp_axis_name=tp_axis_name)
        sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})

        inputs_ = tuple(
            jnp.reshape(x, new_shape) if x is not None else None
            for x, new_shape in zip(inputs, sharding_meta.input_shapes))

        partial_self_fused_attn = partial(_self_fused_attn,
                                          attn_bias_type=attn_bias_type,
                                          attn_mask_type=attn_mask_type,
                                          scaling_factor=scaling_factor,
                                          dropout_probability=dropout_probability,
                                          is_training=is_training)

        output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
                              sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)

        output = jnp.reshape(output_, sharding_meta.output_shapes)

    return output
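

# Single-device usage sketch for `self_fused_attn`. The packed QKV layout
# [batch, seqlen, 3, num_head, head_dim] matches what the wrapper unpacks;
# the bias, mask, and seed shapes below are illustrative assumptions, and the
# helper itself is not part of the module's API.
def _example_self_fused_attn():
    batch, seqlen, num_head, head_dim = 2, 128, 16, 64
    qkv = jnp.zeros((batch, seqlen, 3, num_head, head_dim), dtype=jnp.bfloat16)
    bias = jnp.zeros((1, num_head, seqlen, seqlen), dtype=jnp.bfloat16)
    # 0 marks attendable positions, matching the sequence-length derivation
    # in _self_fused_attn_fwd.
    mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
    seed = jnp.zeros((2,), dtype=jnp.uint32)    # assumed RNG-seed layout
    return self_fused_attn(qkv, bias, mask, seed,
                           attn_bias_type=AttnBiasType.POST_SCALE_BIAS,
                           attn_mask_type=AttnMaskType.PADDING_MASK,
                           scaling_factor=head_dim**-0.5,
                           dropout_probability=0.1,
                           is_training=True)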


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    output, _ = _self_fused_attn_fwd(qkv,
                                     bias,
                                     mask,
                                     seed,
                                     attn_bias_type=attn_bias_type,
                                     attn_mask_type=attn_mask_type,
                                     scaling_factor=scaling_factor,
                                     dropout_probability=dropout_probability,
                                     is_training=is_training)
    return output


def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                         dropout_probability, is_training):

    # Mask entries equal to 0 mark positions to attend to, so counting the
    # zeros in one key column gives each sequence's valid length; the fused
    # kernel consumes these as cumulative offsets with a leading 0.
    seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
    cu_seqlen = jnp.cumsum(seqlen)
    cu_seqlen = jnp.hstack((0, cu_seqlen))

    output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
                                                         bias,
                                                         cu_seqlen,
                                                         seed,
                                                         attn_bias_type=attn_bias_type.value,
                                                         attn_mask_type=attn_mask_type.value,
                                                         scaling_factor=scaling_factor,
                                                         dropout_probability=dropout_probability,
                                                         is_training=is_training)
    return output, (qkv, softmax_aux, rng_state, output, cu_seqlen)


def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                         is_training, ctx, grad):
    qkv, softmax_aux, rng_state, output, cu_seqlen = ctx

    doutput = grad

    grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
                                              softmax_aux,
                                              rng_state,
                                              output,
                                              doutput,
                                              cu_seqlen,
                                              attn_bias_type=attn_bias_type.value,
                                              attn_mask_type=attn_mask_type.value,
                                              scaling_factor=scaling_factor,
                                              dropout_probability=dropout_probability,
                                              is_training=is_training)

    # Compare against the Python-side enum member; `attn_bias_type` here is an
    # AttnBiasType, not a raw NVTE_Bias_Type value.
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_qkv, grad_bias, None, None


_self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd)
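

# Gradient sketch: because _self_fused_attn registers a custom VJP, the public
# wrapper differentiates like any other JAX function. The helper and its
# argument shapes are illustrative, not part of the module's API.
def _example_self_fused_attn_grad(qkv, bias, mask, seed):
    def loss_fn(qkv_):
        out = self_fused_attn(qkv_, bias, mask, seed,
                              attn_bias_type=AttnBiasType.NO_BIAS,
                              attn_mask_type=AttnMaskType.PADDING_MASK,
                              scaling_factor=1.0,
                              dropout_probability=0.0,
                              is_training=True)
        return jnp.sum(out)

    # d(loss)/d(qkv), routed through the registered custom VJP.
    return jax.grad(loss_fn)(qkv)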


def cross_fused_attn(q: jnp.ndarray,
                     kv: jnp.ndarray,
                     mask: jnp.ndarray,
                     seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType,
                     attn_mask_type: AttnMaskType,
                     scaling_factor: float,
                     dropout_probability: float,
                     is_training: bool,
                     sharding_type: ShardingType = ShardingType.SINGLE):
    """
    Cross multi-head attention wrapper
    """
    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
182
        "cross_fused_attn does not support row-split tensor parallelism currently."
183
184

    if sharding_type is ShardingType.SINGLE:
        output = _cross_fused_attn(q,
                                   kv,
                                   mask,
                                   seed,
                                   attn_bias_type=attn_bias_type,
                                   attn_mask_type=attn_mask_type,
                                   scaling_factor=scaling_factor,
                                   dropout_probability=dropout_probability,
                                   is_training=is_training)
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"

        inputs = [q, kv, mask, seed]
        output_shape = q.shape
        sharding_meta = get_fused_attn_sharding_meta(
            sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
            dp_dims=([0, 0, 0, None], [0]),
            tp_dims=([2, 3, None, None], [2]),
            dp_axis_name=dp_axis_name,
            tp_axis_name=tp_axis_name)
        # extend_fsdp_sharding_meta returns (sharding_meta, axis_name), as in
        # the self-attention path above.
        sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})

        inputs_ = tuple(
            jnp.reshape(x, new_shape) if x is not None else None
            for x, new_shape in zip(inputs, sharding_meta.input_shapes))

        partial_cross_fused_attn = partial(_cross_fused_attn,
                                           attn_bias_type=attn_bias_type,
                                           attn_mask_type=attn_mask_type,
                                           scaling_factor=scaling_factor,
                                           dropout_probability=dropout_probability,
                                           is_training=is_training)

        output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
                              sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)

        output = jnp.reshape(output_, sharding_meta.output_shapes)

    return output
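

# Usage sketch for `cross_fused_attn` with distinct query and key/value
# lengths. The packed KV layout [batch, kv_seqlen, 2, num_head, head_dim] and
# the mask shape [batch, 1, q_seqlen, kv_seqlen] follow the unpacking and
# mask slicing in this module; the dtypes and the helper are illustrative.
def _example_cross_fused_attn():
    batch, q_seqlen, kv_seqlen, num_head, head_dim = 2, 128, 256, 16, 64
    q = jnp.zeros((batch, q_seqlen, num_head, head_dim), dtype=jnp.bfloat16)
    kv = jnp.zeros((batch, kv_seqlen, 2, num_head, head_dim),
                   dtype=jnp.bfloat16)
    mask = jnp.zeros((batch, 1, q_seqlen, kv_seqlen), dtype=jnp.uint8)
    seed = jnp.zeros((2,), dtype=jnp.uint32)    # assumed RNG-seed layout
    return cross_fused_attn(q, kv, mask, seed,
                            attn_bias_type=AttnBiasType.NO_BIAS,
                            attn_mask_type=AttnMaskType.PADDING_MASK,
                            scaling_factor=head_dim**-0.5,
                            dropout_probability=0.0,
                            is_training=False)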


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                      attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):

    output, _ = _cross_fused_attn_fwd(q,
                                      kv,
                                      mask,
                                      seed,
                                      attn_bias_type=attn_bias_type,
                                      attn_mask_type=attn_mask_type,
                                      scaling_factor=scaling_factor,
                                      dropout_probability=dropout_probability,
                                      is_training=is_training)
    return output


def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                          dropout_probability, is_training):

    # As in the self-attention path, mask entries equal to 0 mark attendable
    # positions; query and key/value lengths are read off the two mask axes
    # and converted to the cumulative offsets the fused kernel expects.
    q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
    q_cu_seqlen = jnp.cumsum(q_seqlen)
    q_cu_seqlen = jnp.hstack((0, q_cu_seqlen))

    kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32)
    kv_cu_seqlen = jnp.cumsum(kv_seqlen)
    kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))

    output, softmax_aux = cross_fused_attn_fwd(q,
                                               kv,
                                               q_cu_seqlen,
                                               kv_cu_seqlen,
                                               seed,
                                               attn_bias_type=attn_bias_type.value,
                                               attn_mask_type=attn_mask_type.value,
                                               scaling_factor=scaling_factor,
                                               dropout_probability=dropout_probability,
                                               is_training=is_training)
    return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)


def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                          is_training, ctx, grad):
    softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx

    doutput = grad

    grad_q, grad_kv = cross_fused_attn_bwd(q,
                                           kv,
                                           softmax_aux,
                                           doutput,
                                           q_cu_seqlen,
                                           kv_cu_seqlen,
                                           attn_bias_type=attn_bias_type.value,
                                           attn_mask_type=attn_mask_type.value,
                                           scaling_factor=scaling_factor,
                                           dropout_probability=dropout_probability,
                                           is_training=is_training)

    return grad_q, grad_kv, None, None


_cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd)