# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""

from enum import Enum
from functools import partial
import jax
import jax.numpy as jnp

from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout

from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType
from .sharding import xmap_runner, extend_fsdp_sharding_meta

jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)


class AttnBiasType(Enum):
    """Attention Bias Type."""
    NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
    PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
    POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS


class AttnMaskType(Enum):
    """Attention Mask Type."""
    NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
    PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
    CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK


class QKVLayout(Enum):
    """QKV layout"""
    # BS3HD: Q, K and V packed into one [batch, seqlen, 3, num_head, head_dim] tensor.
    BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
    # BSHD_BS2HD: separate Q plus K and V packed into [batch, seqlen, 2, num_head, head_dim].
    BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD


def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type,
                                   dropout_probability, max_seqlen_q, max_seqlen_kv, head_dim):
    """
    Check whether a fused attention kernel is available for the given configuration.
    """
    return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value,
                           attn_mask_type.value, dropout_probability, max_seqlen_q, max_seqlen_kv,
                           head_dim).is_fused_attn_kernel_available()
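
# A minimal availability-check sketch (hypothetical dtypes and sizes, for illustration only):
#
#     available = is_fused_attn_kernel_available(
#         jnp.float16, jnp.float16, QKVLayout.BS3HD, AttnBiasType.NO_BIAS,
#         AttnMaskType.PADDING_MASK, dropout_probability=0.1,
#         max_seqlen_q=512, max_seqlen_kv=512, head_dim=64)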


def self_fused_attn(qkv: jnp.ndarray,
                    bias: jnp.ndarray,
                    mask: jnp.ndarray,
                    seed: jnp.ndarray,
                    attn_bias_type: AttnBiasType,
                    attn_mask_type: AttnMaskType,
                    scaling_factor: float,
                    dropout_probability: float,
                    is_training: bool,
                    sharding_type: ShardingType = ShardingType.SINGLE):
    """
    Fused self-attention wrapper
    """
    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
        "self_fused_attn does not currently support row-split tensor parallelism."

    if sharding_type is ShardingType.SINGLE:
        output = _self_fused_attn(qkv,
                                  bias,
                                  mask,
                                  seed,
                                  attn_bias_type=attn_bias_type,
                                  attn_mask_type=attn_mask_type,
                                  scaling_factor=scaling_factor,
                                  dropout_probability=dropout_probability,
                                  is_training=is_training)
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"

        inputs = [qkv, bias, mask, seed]
        batch, seqlen, _, num_head, head_dim = qkv.shape
        output_shape = [batch, seqlen, num_head, head_dim]
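        # dp_dims/tp_dims give, per input in [qkv, bias, mask, seed] and per output, the
        # dimension to split along the data-parallel / tensor-parallel mesh axis
        # (None = replicated), e.g. qkv is split at batch dim 0 for dp and head dim 3 for tp.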
        sharding_meta = get_fused_attn_sharding_meta(
            sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
            dp_dims=([0, None, 0, 0], [0]),
            tp_dims=([3, 1, None, 0], [2]),
            dp_axis_name=dp_axis_name,
            tp_axis_name=tp_axis_name)
        sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})

        inputs_ = tuple(
            jnp.reshape(x, new_shape) if x is not None else None
            for x, new_shape in zip(inputs, sharding_meta.input_shapes))

        partial_self_fused_attn = partial(_self_fused_attn,
                                          attn_bias_type=attn_bias_type,
                                          attn_mask_type=attn_mask_type,
                                          scaling_factor=scaling_factor,
                                          dropout_probability=dropout_probability,
                                          is_training=is_training)

        output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
                              sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)

        output = jnp.reshape(output_, sharding_meta.output_shapes)

    return output
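
# A minimal single-device usage sketch (hypothetical shapes; math import assumed).
# qkv is packed as BS3HD, i.e. [batch, seqlen, 3, num_head, head_dim]:
#
#     out = self_fused_attn(qkv, None, mask, seed,
#                           attn_bias_type=AttnBiasType.NO_BIAS,
#                           attn_mask_type=AttnMaskType.PADDING_MASK,
#                           scaling_factor=1.0 / math.sqrt(head_dim),
#                           dropout_probability=0.1,
#                           is_training=True)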


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                     scaling_factor: float, dropout_probability: float, is_training: bool):
    output, _ = _self_fused_attn_fwd(qkv,
                                     bias,
                                     mask,
                                     seed,
                                     attn_bias_type=attn_bias_type,
                                     attn_mask_type=attn_mask_type,
                                     scaling_factor=scaling_factor,
                                     dropout_probability=dropout_probability,
                                     is_training=is_training)
    return output


def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                         dropout_probability, is_training):

    # Derive each sequence's valid length from the padding mask (entries equal to 0 are
    # attended, nonzero entries are masked out), then prepend 0 and accumulate into the
    # cumulative-sequence-length form the fused kernel expects.
    seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
    cu_seqlen = jnp.cumsum(seqlen)
    cu_seqlen = jnp.hstack((0, cu_seqlen))

    output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
                                                         bias,
                                                         cu_seqlen,
                                                         seed,
                                                         attn_bias_type=attn_bias_type.value,
                                                         attn_mask_type=attn_mask_type.value,
                                                         scaling_factor=scaling_factor,
                                                         dropout_probability=dropout_probability,
                                                         is_training=is_training)
    return output, (qkv, softmax_aux, rng_state, output, cu_seqlen)


def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                         is_training, ctx, grad):
    qkv, softmax_aux, rng_state, output, cu_seqlen = ctx

    doutput = grad

    grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
                                              softmax_aux,
                                              rng_state,
                                              output,
                                              doutput,
                                              cu_seqlen,
                                              attn_bias_type=attn_bias_type.value,
                                              attn_mask_type=attn_mask_type.value,
                                              scaling_factor=scaling_factor,
                                              dropout_probability=dropout_probability,
                                              is_training=is_training)

    # attn_bias_type is an AttnBiasType enum member, so it must be compared against
    # AttnBiasType (a comparison against NVTE_Bias_Type would never match).
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None

    return grad_qkv, grad_bias, None, None


_self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd)
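
# A minimal differentiation sketch (hypothetical inputs): jax.grad flows through the
# custom VJP registered above, so the qkv gradient is produced by self_fused_attn_bwd:
#
#     loss = lambda qkv: jnp.sum(self_fused_attn(qkv, None, mask, seed,
#                                                attn_bias_type=AttnBiasType.NO_BIAS,
#                                                attn_mask_type=AttnMaskType.PADDING_MASK,
#                                                scaling_factor=0.125,
#                                                dropout_probability=0.0,
#                                                is_training=True))
#     dqkv = jax.grad(loss)(qkv)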


def cross_fused_attn(q: jnp.ndarray,
                     kv: jnp.ndarray,
                     mask: jnp.ndarray,
                     seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType,
                     attn_mask_type: AttnMaskType,
                     scaling_factor: float,
                     dropout_probability: float,
                     is_training: bool,
                     sharding_type: ShardingType = ShardingType.SINGLE):
    """
    Fused cross-attention wrapper
    """
    assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
        "cross_fused_attn does not currently support row-split tensor parallelism."

    if sharding_type is ShardingType.SINGLE:
        output = _cross_fused_attn(q,
                                   kv,
                                   mask,
                                   seed,
                                   attn_bias_type=attn_bias_type,
                                   attn_mask_type=attn_mask_type,
                                   scaling_factor=scaling_factor,
                                   dropout_probability=dropout_probability,
                                   is_training=is_training)
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"

        inputs = [q, kv, mask, seed]
        output_shape = q.shape
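        # As in self_fused_attn, dp_dims/tp_dims map each input in [q, kv, mask, seed] and
        # each output to the dimension split along the dp / tp mesh axis (None = replicated).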
        sharding_meta = get_fused_attn_sharding_meta(
            sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
            dp_dims=([0, 0, 0, None], [0]),
            tp_dims=([2, 3, None, None], [2]),
            dp_axis_name=dp_axis_name,
            tp_axis_name=tp_axis_name)
        sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})

        inputs_ = tuple(
            jnp.reshape(x, new_shape) if x is not None else None
            for x, new_shape in zip(inputs, sharding_meta.input_shapes))

        partial_cross_fused_attn = partial(_cross_fused_attn,
                                           attn_bias_type=attn_bias_type,
                                           attn_mask_type=attn_mask_type,
                                           scaling_factor=scaling_factor,
                                           dropout_probability=dropout_probability,
                                           is_training=is_training)

        output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
                              sharding_meta.out_axes, sharding_meta.axis_resources, inputs_)

        output = jnp.reshape(output_, sharding_meta.output_shapes)

    return output
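
# A minimal single-device usage sketch (hypothetical shapes; math import assumed):
# q is [batch, q_seqlen, num_head, head_dim] and kv is packed BS2HD, i.e.
# [batch, kv_seqlen, 2, num_head, head_dim]:
#
#     out = cross_fused_attn(q, kv, mask, seed,
#                            attn_bias_type=AttnBiasType.NO_BIAS,
#                            attn_mask_type=AttnMaskType.PADDING_MASK,
#                            scaling_factor=1.0 / math.sqrt(head_dim),
#                            dropout_probability=0.0,
#                            is_training=False)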


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
                      attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                      scaling_factor: float, dropout_probability: float, is_training: bool):

    output, _ = _cross_fused_attn_fwd(q,
                                      kv,
                                      mask,
                                      seed,
                                      attn_bias_type=attn_bias_type,
                                      attn_mask_type=attn_mask_type,
                                      scaling_factor=scaling_factor,
                                      dropout_probability=dropout_probability,
                                      is_training=is_training)
    return output


def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
                          dropout_probability, is_training):

    # Q and KV valid lengths are read from the mask independently: rows give the valid
    # Q positions, columns the valid KV positions; both are converted to cumulative
    # form with a leading 0, as the fused kernel expects.
    q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
    q_cu_seqlen = jnp.cumsum(q_seqlen)
    q_cu_seqlen = jnp.hstack((0, q_cu_seqlen))

    kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32)
    kv_cu_seqlen = jnp.cumsum(kv_seqlen)
    kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))

    output, softmax_aux = cross_fused_attn_fwd(q,
                                               kv,
                                               q_cu_seqlen,
                                               kv_cu_seqlen,
                                               seed,
                                               attn_bias_type=attn_bias_type.value,
                                               attn_mask_type=attn_mask_type.value,
                                               scaling_factor=scaling_factor,
                                               dropout_probability=dropout_probability,
                                               is_training=is_training)
    return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)


def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                          is_training, ctx, grad):
    softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx

    doutput = grad

    grad_q, grad_kv = cross_fused_attn_bwd(q,
                                           kv,
                                           softmax_aux,
                                           doutput,
                                           q_cu_seqlen,
                                           kv_cu_seqlen,
                                           attn_bias_type=attn_bias_type.value,
                                           attn_mask_type=attn_mask_type.value,
                                           scaling_factor=scaling_factor,
                                           dropout_probability=dropout_probability,
                                           is_training=is_training)

    return grad_q, grad_kv, None, None


_cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd)