# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Attention API"""

import math
import os
import warnings
from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F
import transformer_engine_paddle as tex

from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
from ..constants import (AttnTypes, TE_DType, AttnBiasType, AttnMaskType, FusedAttnBackend,
                         dist_group_type)
from ..cpp_extensions import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
    fused_attn_fwd,
    fused_attn_bwd,
    mask_to_cu_seqlens,
)
from ..distributed import get_tp_group_and_world_size, track_rng_state
from ..utils import attention_mask_func, divide
from ..recompute import recompute

__all__ = ["DotProductAttention", "MultiHeadAttention"]


def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    """
    Used to repeat the key and value states for GQA.
    The hidden states go from (batch, seqlen, num_gqa_groups, head_size)
    to (batch, seqlen, num_heads, head_size)
    """
    batch, seqlen, num_gqa_groups, head_size = hidden_states.shape
    if n_rep == 1:
        return hidden_states

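    # Add a repeat axis after the GQA-group dim, tile it n_rep times, then fold it back
    # into the head dim: (b, s, ng, d) -> (b, s, ng, n_rep, d) -> (b, s, ng * n_rep, d).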
    hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
    return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size])


class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, attn_bias, max_seqlen, attn_scale, qkv_dtype, dropout_p,
                set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training,
                fused_attention_backend):
        """Forward function for FusedAttention with packed QKV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
            qkv,
            cu_seqlens,
            is_training,
            max_seqlen,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

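        # Stash tensors and attention metadata needed to recompute gradients in backward().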
        ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux)
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed QKV input"""
        qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor()
        dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out, softmax_aux,
                                               ctx.fused_attention_backend, ctx.max_seqlen,
                                               ctx.qkv_dtype, ctx.attn_scale, ctx.dropout_p,
                                               ctx.set_zero, ctx.qkv_layout, ctx.attn_bias_type,
                                               ctx.attn_mask_type)

        # if no_bias, return dqkv
        if ctx.attn_bias_type == "no_bias":
            return (dqkv, None)
        # else, return (dqkv, dbias)
        return (dqkv, None, rest[0])


class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed KV input"""

    @staticmethod
    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
                attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
                attn_mask_type, is_training, fused_attention_backend):
        """Forward function for FusedAttention with packed KV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
            q, kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, qkv_dtype,
            fused_attention_backend, attn_bias, attn_scale, dropout_p, set_zero, qkv_layout,
            attn_bias_type, attn_mask_type)

        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed KV input"""
        q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dkv, *rest = fused_attn_bwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
                                                 d_out, softmax_aux, ctx.fused_attention_backend,
                                                 ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
                                                 ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
                                                 ctx.qkv_layout, ctx.attn_bias_type,
                                                 ctx.attn_mask_type)

        # if no_bias, return dq, dkv
        if ctx.attn_bias_type == "no_bias":
            return (dq, dkv, None, None)
        # else, return (dq, dkv, dbias)
        return (dq, dkv, None, None, rest[0])


class FusedAttnFunc(paddle.autograd.PyLayer):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
                attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
                attn_mask_type, is_training, fused_attention_backend):
        """Forward function for FusedAttention with separate Q, K, V tensors"""
        out, softmax_aux, rng_state = fused_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_kv,
                                                     is_training, max_seqlen_q, max_seqlen_kv,
                                                     qkv_dtype, fused_attention_backend, attn_bias,
                                                     attn_scale, dropout_p, set_zero, qkv_layout,
                                                     attn_bias_type, attn_mask_type)

        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with separate Q, K, V tensors"""
        q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dk, dv, *rest = fused_attn_bwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
                                           d_out, softmax_aux, ctx.fused_attention_backend,
                                           ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
                                           ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
                                           ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
        # if no_bias, return dq, dk, dv
        if ctx.attn_bias_type == "no_bias":
            return (dq, dk, dv, None, None)
        # else, return (dq, dk, dv, dbias)
        return (dq, dk, dv, None, None, rest[0])


class DotProductAttention(paddle.nn.Layer):
    """
    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    num_attention_heads: int
            number of attention heads in the transformer layer.
    kv_channels: int
            number of channels in the key and value tensors.
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation.
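
    Examples
    --------
    A minimal usage sketch (the sizes, dtype, and random tensors below are illustrative
    assumptions, not requirements):

    .. code-block:: python

        attn = DotProductAttention(num_attention_heads=16, kv_channels=64)
        # q, k, v: [batch, seqlen, heads, head_size]
        q = paddle.randn([2, 128, 16, 64]).astype('bfloat16')
        k = paddle.randn([2, 128, 16, 64]).astype('bfloat16')
        v = paddle.randn([2, 128, 16, 64]).astype('bfloat16')
        out = attn(q, k, v)    # [batch, seqlen, heads, head_size]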
    """

    def __init__(self,
                 num_attention_heads: int,
                 kv_channels: int,
                 num_gqa_groups: Optional[int] = None,
                 attention_dropout: float = 0.1,
                 attn_mask_type: str = "causal",
                 attention_type: str = "self",
                 tp_size: int = 1,
                 backend: str = 'transformer_engine') -> None:
        super().__init__()

        self.attn_mask_type = attn_mask_type
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.qkv_layout = "bshd_bshd_bshd"
        self.hidden_size_per_attention_head = kv_channels
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.tp_size = tp_size
        self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups

        self.backend = backend

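        # NVTE_FUSED_ATTN=0 disables the fused attention kernels and forces the Paddle fallback.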
        self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))

        if not self.use_fused_attention and backend == 'transformer_engine':
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            self.backend = 'paddle'

        if self.backend != 'transformer_engine':
            self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type,
                                                            attention_mask_func,
                                                            backend=self.backend)

    def forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
            is set to `"causal"`.


        Parameters
        ----------
        query_layer : paddle.Tensor
                      Query tensor.
        key_layer : paddle.Tensor
                      Key tensor.
        value_layer : paddle.Tensor
                      Value tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                         Boolean tensor used to mask out softmax input when not using attention.
        core_attention_bias_type: str, default = `no_bias`
                                  currently only the `no_bias` type is supported, {`no_bias`}
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                             Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                  Whether to use the fast path to set output tensors to 0 or not.
        """

        backend = self.backend

        assert (key_layer.shape == value_layer.shape), "Keys and values must have the same shape!"
        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
               ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"

        if backend == 'transformer_engine':
            max_s_q = query_layer.shape[1]
            max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1]
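            # Query which fused kernel (if any) supports this dtype/layout/bias/mask/seq-len
            # combination; fall back to the Paddle implementation when none is available.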
            self.fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
                tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
                AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2],
                key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], max_s_q,
                max_s_kv, query_layer.shape[-1])

            is_backend_avail = (self.fused_attention_backend in [
                FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
            ])
            if is_backend_avail and self.use_fused_attention:
                return self._te_forward(query_layer, key_layer, value_layer, attention_mask,
                                        core_attention_bias_type, core_attention_bias, set_zero)
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            backend = 'paddle'
            self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type,
                                                            attention_mask_func,
                                                            backend=backend)
        if backend == 'paddle':
            if core_attention_bias_type != "no_bias":
                warnings.warn("Paddle backend dot product attention does not support bias yet. "
                              "Bias will be ignored.")
            return self._pd_forward(query_layer, key_layer, value_layer, attention_mask)
        raise AttributeError(f"Backend {backend} is not supported.")

    def _te_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:

        if self.attention_type == "self":
            # self attention - q, k, v: [b, s, h, d]
            assert (len(query_layer.shape) == 4 and len(key_layer.shape) == 4
                    and len(value_layer.shape)
                    == 4), "q,k,v shape must be [b, s, h, d] for dot product self attention"
            max_seqlen = query_layer.shape[1]
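            # With a causal mask (or no mask) every sequence is treated as full length,
            # so cu_seqlens is simply [0, s, 2s, ...]; otherwise it is derived from the mask.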
            if self.attn_mask_type == "causal" or attention_mask is None:
                cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1],
                                           step=query_layer.shape[1],
                                           dtype='int32')
            else:
                cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
            qkv_dtype = TE_DType[query_layer.dtype]

            output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens,
                                         cu_seqlens, core_attention_bias, max_seqlen, max_seqlen,
                                         1.0 / self.norm_factor, qkv_dtype,
                                         self.attention_dropout if self.training else 0.0, set_zero,
                                         self.qkv_layout, core_attention_bias_type,
                                         self.attn_mask_type, self.training,
                                         self.fused_attention_backend)
        elif self.attention_type == "cross":
            # cross attention - q: [b, s_q, h, d]  k,v: [b, s_kv, h, d]
            assert (
                len(query_layer.shape) == 4 and len(key_layer.shape) == 4
                and len(value_layer.shape) == 4
            ), "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" \
                "for dot product cross attention"
            assert (attention_mask
                    is not None), "attention_mask must be provided for cross attention"
            max_seqlen_q = query_layer.shape[1]
            max_seqlen_kv = key_layer.shape[1]
            cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
            qkv_dtype = TE_DType[query_layer.dtype]
            output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens_q,
                                         cu_seqlens_kv, core_attention_bias, max_seqlen_q,
                                         max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
                                         self.attention_dropout if self.training else 0.0, set_zero,
                                         self.qkv_layout, core_attention_bias_type,
                                         self.attn_mask_type, self.training,
                                         self.fused_attention_backend)
        else:
            raise ValueError("attention_type must be one of ['self', 'cross']")
        return output

    def _pd_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
    ) -> paddle.Tensor:

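        # Unfused reference path: expand GQA key/value heads to the full head count,
        # then run scaled dot-product attention with plain Paddle ops.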
        q = query_layer
        k = repeat_kv(key_layer, self.num_queries_per_key_value)
        v = repeat_kv(value_layer, self.num_queries_per_key_value)

        q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
        k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
        v = paddle.transpose(x=v, perm=[0, 2, 1, 3])

        product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True)
        attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None)

        if self.attention_dropout > 0:
            attention_probs = F.dropout(
                attention_probs,
                self.attention_dropout,
                training=self.training,
            )

        out = paddle.matmul(attention_probs, v)
        out = paddle.transpose(out, perm=[0, 2, 1, 3])    # [b, s, h, d]
        # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
        return out


class MultiHeadAttention(paddle.nn.Layer):
    """
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    hidden_size: int
                    hidden size of the model.
    num_attention_heads: int
                    number of attention heads.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon: float, default = 1e-5
                          epsilon to use in the layer norm operations.
    weight_attr: Union[paddle.ParamAttr, None], default = `None`
                    paddle.ParamAttr object for the weight parameter.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = `None`
                    paddle.ParamAttr object for the bias parameter.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    params_dtype: Optional[paddle.dtype], default = `None`
                    data type for the weights and biases.
    return_layernorm_output: bool, default = `False`
                    whether to return the output of the layernorm operation.
    input_layernorm: bool, default = `False`
                    whether to apply layernorm to the input.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    zero_centered_gamma: bool, default = `False`
                    whether to zero initialize the gamma of the layernorm operation.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation. If set to 'paddle', a framework-only,
             non-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    num_gqa_groups : int, default = `None`
                     number of GQA groups in the transformer layer.
                     Grouped Query Attention is described in
                     `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                     This only affects the keys and values, not the queries.
                     GQA-1 is equivalent to Multi-Query Attention
                     (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                     is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    rng_state_name : str, default = `local_seed`
                   Controls the rng state used for dropout on attention probs. The
                   specified rng should be seeded differently on different TP ranks.
                   It will be ignored if `set_parallel_mode` is False. The specified
                   name should be registered through
                   `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
                   .add(rng_state_name, seed)`.
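
    Examples
    --------
    A minimal self-attention sketch (the sizes below are illustrative assumptions; the fused
    attention path additionally requires fp16/bf16 data, otherwise the Paddle backend is used):

    .. code-block:: python

        mha = MultiHeadAttention(hidden_size=1024, num_attention_heads=16)
        # hidden_states: [batch, seqlen, hidden_size]
        x = paddle.randn([2, 128, 1024])
        out = mha(x)    # [batch, seqlen, hidden_size]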
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        attn_mask_type: str = "causal",
        params_dtype: Optional[paddle.dtype] = None,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        zero_centered_gamma: bool = False,
        set_parallel_mode: bool = False,
        sequence_parallel: bool = False,
        tp_group: Optional[dist_group_type] = None,
        num_gqa_groups: Optional[int] = None,
        rng_state_name: str = 'local_seed',
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
        self.weight_attr = weight_attr
        self.bias_attr = bias_attr
        self.attn_mask_type = attn_mask_type

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                  enable_tp=set_parallel_mode)
        self.tensor_parallel = self.tp_size > 1
        self.sequence_parallel = self.tensor_parallel and sequence_parallel
        self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads
        self.set_parallel_mode = set_parallel_mode
        self.rng_state_name = rng_state_name
        self.backend = backend

        self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
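        # GQA bookkeeping: the key/value projection width is num_gqa_groups groups of
        # head_size channels (equal to hidden_size when num_gqa_groups == num_attention_heads).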
        self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
        assert (self.num_attention_heads % self.num_gqa_groups == 0
               ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (self.num_gqa_groups % self.tp_size == 0
               ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads)
        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )

        else:    # cross attention
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                self.weight_attr,
                self.bias_attr,
                parallel_mode=qkv_parallel_mode,
                sequence_parallel=self.sequence_parallel,
                tp_group=self.tp_group,
                backend=self.backend,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.num_gqa_groups,
            attention_dropout,
            attn_mask_type=attn_mask_type,
            attention_type=self.attention_type,
            tp_size=self.tp_size,
            backend=self.backend,
        )

        # Linear
        self.proj = Linear(
            hidden_size,
            hidden_size,
            self.weight_attr,
            self.bias_attr,
            parallel_mode="row" if set_parallel_mode else None,
            sequence_parallel=self.sequence_parallel,
            tp_group=self.tp_group,
            backend=self.backend,
        )

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_output: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
        recompute_core_attention: bool = False,
        is_first_microbatch: Optional[bool] = None,
    ) -> Tuple[Union[paddle.Tensor, None], ...]:
        """
        MultiHeadAttention Layer.

        Parameters
        ----------
        hidden_states : paddle.Tensor
                        Input tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                        Boolean tensor used to mask out softmax input when not using attention.
        encoder_output : Optional[paddle.Tensor], default = `None`
                        Output of the encoder layer.
        core_attention_bias_type: str, default = `no_bias`
                                currently only the `no_bias` type is supported, {`no_bias`}
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                    Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
        recompute_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
        """

        if self.attn_mask_type != "causal" and attention_mask is not None:
            assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"

        input_dim = len(hidden_states.shape)
        if input_dim == 2:
            # hidden_states: [b * s_q, hidden_size]
            # need to get max_seq_len from attention_mask
            assert attention_mask is not None
            max_seq_len = attention_mask.shape[-1]
        elif input_dim == 3:
            # hidden_states: [b, s_q, hidden_size]
            max_seq_len = hidden_states.shape[1]
        else:
            raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")

        if self.attention_type == "self":
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_qkv_layer = layernorm_qkv_outputs
            else:
                mixed_qkv_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            num_queries_per_key_value = (self.num_attention_heads_per_partition //
                                         self.num_gqa_groups_per_partition)

            # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d]
            mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
                -1, max_seq_len, (
                    num_queries_per_key_value +
                    2), self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
            ])

            # [b, s_q, (h/ng+2), ng, d]
            # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d]
            query_layer, key_layer, value_layer = paddle.split(
                mixed_qkv_layer,
                num_or_sections=(num_queries_per_key_value, 1, 1),
                axis=2,
            )

            # query: -> [b, s, h, d]
            # key, value: -> [b, s, ng, d]
            query_layer, key_layer, value_layer = (x.reshape(
                shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
                                                   for x in (query_layer, key_layer, value_layer))

            with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
                if recompute_core_attention:
                    context_layer = recompute(
                        self.core_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        attention_mask,
                        core_attention_bias_type,
                        core_attention_bias,
                        set_zero,
                        use_reentrant=False,
                    )
                else:
                    context_layer = self.core_attention(
                        query_layer=query_layer,
                        key_layer=key_layer,
                        value_layer=value_layer,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        set_zero=set_zero,
                    )

        else:    # cross attention
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
            )
            # [b, s_kv, 2 * hidden_size_kv] --> [b, s_kv, 2 * ng, head_size]
            mixed_kv_layer = mixed_kv_layer.reshape(shape=[
                0, 0, 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
            ])

            # [b, s_kv, 2 * ng, head_size]
            # --> 2 [b, s_kv, ng, head_size]
            key_layer, value_layer = paddle.split(
                mixed_kv_layer,
                num_or_sections=2,
                axis=2,
            )

            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [b, s, hidden_size] --> [b, s, h, d]
            query_layer = query_layer.reshape(shape=[
                -1, max_seq_len, self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head
            ])
            with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
                if recompute_core_attention:
                    context_layer = recompute(
                        self.core_attention,
                        query_layer,
                        key_layer,
                        value_layer,
                        attention_mask,
                        core_attention_bias_type,
                        core_attention_bias,
                        set_zero,
                        use_reentrant=False,
                    )
                else:
                    context_layer = self.core_attention(
                        query_layer=query_layer,
                        key_layer=key_layer,
                        value_layer=value_layer,
                        attention_mask=attention_mask,
                        core_attention_bias_type=core_attention_bias_type,
                        core_attention_bias=core_attention_bias,
                        set_zero=set_zero,
                    )

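        # Merge the head dims: [b, s, h, d] -> [b, s, h*d] (or [b*s, h*d] for 2-D inputs).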
        if input_dim == 3:
            context_layer = paddle.reshape(
                context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]])
        else:    # input_dim == 2
            context_layer = paddle.reshape(context_layer,
                                           [-1, context_layer.shape[2] * context_layer.shape[3]])

        # Output. [b, s, hidden]
        attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch)

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, layernorm_output
        return attention_output