# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Attntion API"""

import math
import os
import warnings
from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None
from transformer_engine import transformer_engine_paddle as tex

from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
from ..constants import (
    AttnTypes,
    TE_DType,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
    dist_group_type,
)
from ..cpp_extensions import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
    fused_attn_fwd,
    fused_attn_bwd,
    mask_to_cu_seqlens,
)
from ..distributed import get_tp_group_and_world_size, track_rng_state
from ..utils import attention_mask_func, divide
from ..recompute import recompute

__all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"]


def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    """
    Used to repeat the key and value states for GQA.
    The hidden states go from (batch, seqlen, num_gqa_groups, head_size)
    to (batch, seqlen, num_heads, head_size)
    """
    batch, seqlen, num_gqa_groups, head_size = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
    return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size])
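

# Example (assumed shapes): hidden_states of shape [2, 16, 4, 64] with n_rep=3
# becomes [2, 16, 12, 64]; each KV head is repeated 3 times along the head axis.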


class RotaryPositionEmbedding(paddle.nn.Layer):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        max_position_embeddings: int
            max_position_embeddings before position interpolation
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = 1.0 / (
            10000 ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / self.dim)
        )
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype="float32")
        # [seq_len, dim/2]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        # [1, seqlen, 1, dim]
        self.cos_cached = emb.cos()[None, :, None, :]
        self.sin_cached = emb.sin()[None, :, None, :]

    def forward(self, max_seq_len: int):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        """
        # cache layout is [1, seq_len, 1, dim]; truncate along the sequence axis
        cos = self.cos_cached[:, :max_seq_len, :, :]
        sin = self.sin_cached[:, :max_seq_len, :, :]
        return (cos, sin)
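

# A minimal usage sketch (hypothetical sizes):
#     rope = RotaryPositionEmbedding(dim=64, max_position_embeddings=2048)
#     cos, sin = rope(max_seq_len=128)  # each of shape [1, 128, 1, 64]
#     q, k = apply_rotary_pos_emb(q, k, cos, sin)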


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return paddle.concat([-x2, x1], axis=-1)  # shape is the same as x


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Applies rotary positional embedding to the input."""

    if position_ids is None:
        # Note: Only for LlamaForCausalLMPipe model pretraining
        cos = cos[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
        sin = sin[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
    else:
        cos = cos.squeeze(axis=[0, 2])  # [seq_len, dim]
        sin = sin.squeeze(axis=[0, 2])  # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
        sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
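
    # Elementwise rotation: for each feature pair (x1, x2) this computes
    # (x1 * cos - x2 * sin, x2 * cos + x1 * sin); rotate_half supplies the pairing.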
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        attn_bias,
        max_seqlen,
        attn_scale,
        qkv_dtype,
        dropout_p,
        set_zero,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        is_training,
        fused_attention_backend,
    ):
        """Forward function for FusedAttention with packed QKV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
            qkv,
            cu_seqlens,
            is_training,
            max_seqlen,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux)
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed QKV input"""
        qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor()
        dqkv, *rest = fused_attn_bwd_qkvpacked(
            qkv,
            cu_seqlens,
            rng_state,
            out,
            d_out,
            softmax_aux,
            ctx.fused_attention_backend,
            ctx.max_seqlen,
            ctx.qkv_dtype,
            ctx.attn_scale,
            ctx.dropout_p,
            ctx.set_zero,
            ctx.qkv_layout,
            ctx.attn_bias_type,
            ctx.attn_mask_type,
        )
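        # PyLayer.backward returns one gradient per tensor input; cu_seqlens is an
        # integer tensor and carries no gradient, hence the None placeholder below.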

        # if no_bias, return dqkv
        if ctx.attn_bias_type == "no_bias":
            return (dqkv, None)
        # else, return (dqkv, dbias)
        return (dqkv, None, rest[0])


class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed KV input"""

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        attn_bias,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        qkv_dtype,
        dropout_p,
        set_zero,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        is_training,
        fused_attention_backend,
    ):
        """Forward function for FusedAttention with packed KV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
            q,
            kv,
            cu_seqlens_q,
            cu_seqlens_kv,
            is_training,
            max_seqlen_q,
            max_seqlen_kv,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed KV input"""
        q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dkv, *rest = fused_attn_bwd_kvpacked(
            q,
            kv,
            cu_seqlens_q,
            cu_seqlens_kv,
            rng_state,
            out,
            d_out,
            softmax_aux,
            ctx.fused_attention_backend,
            ctx.max_seqlen_q,
            ctx.max_seqlen_kv,
            ctx.qkv_dtype,
            ctx.attn_scale,
            ctx.dropout_p,
            ctx.set_zero,
            ctx.qkv_layout,
            ctx.attn_bias_type,
            ctx.attn_mask_type,
        )

        # if no_bias, return dq, dkv
        if ctx.attn_bias_type == "no_bias":
            return (dq, dkv, None, None)
        # else, return (dq, dkv, dbias)
        return (dq, dkv, None, None, rest[0])


class FusedAttnFunc(paddle.autograd.PyLayer):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        attn_bias,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        qkv_dtype,
        dropout_p,
        set_zero,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        is_training,
        fused_attention_backend,
    ):
        """Forward function for FusedAttention with separate Q, K, V tensors"""
        out, softmax_aux, rng_state = fused_attn_fwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            is_training,
            max_seqlen_q,
            max_seqlen_kv,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with separate Q, K, V tensors"""
        q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dk, dv, *rest = fused_attn_bwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            rng_state,
            out,
            d_out,
            softmax_aux,
            ctx.fused_attention_backend,
            ctx.max_seqlen_q,
            ctx.max_seqlen_kv,
            ctx.qkv_dtype,
            ctx.attn_scale,
            ctx.dropout_p,
            ctx.set_zero,
            ctx.qkv_layout,
            ctx.attn_bias_type,
            ctx.attn_mask_type,
        )
        # if no_bias, return dq, dk, dv
        if ctx.attn_bias_type == "no_bias":
            return (dq, dk, dv, None, None)
        # else, return (dq, dk, dv, dbias)
        return (dq, dk, dv, None, None, rest[0])


class DotProductAttention(paddle.nn.Layer):
    """
    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    num_attention_heads: int
            number of attention heads in the transformer layer.
    kv_channels: int
            number of channels in the key and value tensors.
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation.
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.1,
        attn_mask_type: str = "causal",
        attention_type: str = "self",
        tp_size: int = 1,
        backend: str = "transformer_engine",
    ) -> None:
        super().__init__()

        self.attn_mask_type = attn_mask_type
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.qkv_layout = "bshd_bshd_bshd"
        self.hidden_size_per_attention_head = kv_channels
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.tp_size = tp_size
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups

        self.backend = backend

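        # NVTE_FUSED_ATTN=1 (default) enables the fused attention path;
        # NVTE_FUSED_ATTN=0 forces the un-fused Paddle implementation below.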
        self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))

        if not self.use_fused_attention and backend == "transformer_engine":
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            self.backend = "paddle"

        if self.backend != "transformer_engine":
            self.scale_mask_softmax = FusedScaleMaskSoftmax(
                attn_mask_type, attention_mask_func, backend=self.backend
            )

    def forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
            is set to `"causal"`.


        Parameters
        ----------
        query_layer : paddle.Tensor
                      Query tensor.
        key_layer : paddle.Tensor
                      Key tensor.
        value_layer : paddle.Tensor
                      Value tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                         Boolean tensor used to mask out the softmax input.
        core_attention_bias_type: str, default = `no_bias`
                                  only the `no_bias` type is currently supported
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                             Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                  Whether to use the fast path to set output tensors to 0 or not.
        """

        backend = self.backend

        assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
        assert (
            key_layer.shape[-2] == self.num_gqa_groups_per_partition
        ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"

        if backend == "transformer_engine":
            max_s_q = query_layer.shape[1]
            max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1]
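            # Ask the library which fused kernel (if any) supports this
            # dtype / layout / bias / mask / head-size / seq-len combination.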
            self.fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype],
                TE_DType[query_layer.dtype],
                tex.get_nvte_qkv_layout(self.qkv_layout),
                AttnBiasType[core_attention_bias_type],
                AttnMaskType[self.attn_mask_type],
                self.attention_dropout,
                query_layer.shape[-2],
                key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2],
                max_s_q,
                max_s_kv,
                query_layer.shape[-1],
            )

            is_backend_avail = self.fused_attention_backend in [
                FusedAttnBackend["F16_max512_seqlen"],
                FusedAttnBackend["F16_arbitrary_seqlen"],
            ]
            if is_backend_avail and self.use_fused_attention:
                return self._te_forward(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask,
                    core_attention_bias_type,
                    core_attention_bias,
                    set_zero,
                )
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            backend = "paddle"
            self.scale_mask_softmax = FusedScaleMaskSoftmax(
                self.attn_mask_type, attention_mask_func, backend=backend
            )
        if backend == "paddle":
            if core_attention_bias_type != "no_bias":
                warnings.warn(
                    "Paddle backend dot product attention does not support bias yet. "
                    "Bias will be ignored."
                )
            return self._pd_forward(query_layer, key_layer, value_layer, attention_mask)
        raise AttributeError(f"Backend {backend} is not supported.")

    def _te_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:

        if self.attention_type == "self":
            # self attention - q: [b, s, h, d]  k, v: [b, s, ng, d]
            assert (
                len(query_layer.shape) == 4
                and len(key_layer.shape) == 4
                and len(value_layer.shape) == 4
            ), "q,k,v shape must be [b, s, h, d] for dot product self attention"
            max_seqlen = query_layer.shape[1]
            if self.attn_mask_type == "causal" or attention_mask is None:
                cu_seqlens = paddle.arange(
                    0,
                    (query_layer.shape[0] + 1) * query_layer.shape[1],
                    step=query_layer.shape[1],
                    dtype="int32",
                )
            else:
                cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
            qkv_dtype = TE_DType[query_layer.dtype]

            output = FusedAttnFunc.apply(
                query_layer,
                key_layer,
                value_layer,
                cu_seqlens,
                cu_seqlens,
                core_attention_bias,
                max_seqlen,
                max_seqlen,
                1.0 / self.norm_factor,
                qkv_dtype,
                self.attention_dropout if self.training else 0.0,
                set_zero,
                self.qkv_layout,
                core_attention_bias_type,
                self.attn_mask_type,
                self.training,
                self.fused_attention_backend,
            )
        elif self.attention_type == "cross":
            # cross attention - q: [b, s_q, h, d]  k, v: [b, s_kv, ng, d]
            assert (
                len(query_layer.shape) == 4
                and len(key_layer.shape) == 4
                and len(value_layer.shape) == 4
            ), (
                "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]"
                "for dot product cross attention"
            )
            assert attention_mask is not None, "attention_mask must be provided for cross attention"
            max_seqlen_q = query_layer.shape[1]
            max_seqlen_kv = key_layer.shape[1]
            cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
            qkv_dtype = TE_DType[query_layer.dtype]
            output = FusedAttnFunc.apply(
                query_layer,
                key_layer,
                value_layer,
                cu_seqlens_q,
                cu_seqlens_kv,
                core_attention_bias,
                max_seqlen_q,
                max_seqlen_kv,
                1.0 / self.norm_factor,
                qkv_dtype,
                self.attention_dropout if self.training else 0.0,
                set_zero,
                self.qkv_layout,
                core_attention_bias_type,
                self.attn_mask_type,
                self.training,
                self.fused_attention_backend,
            )
        else:
            raise ValueError("attention_type must be one of ['self', 'cross']")
        return output

    def _pd_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
    ) -> paddle.Tensor:

        q = query_layer
        k = repeat_kv(key_layer, self.num_queries_per_key_value)
        v = repeat_kv(value_layer, self.num_queries_per_key_value)

        q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
        k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
        v = paddle.transpose(x=v, perm=[0, 2, 1, 3])

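        # Scaled dot-product scores: (Q / sqrt(d)) @ K^T -> [b, h, s_q, s_kv]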
        product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True)
        attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None)

        if self.attention_dropout > 0:
            attention_probs = F.dropout(
                attention_probs,
                self.attention_dropout,
                training=self.training,
            )

        out = paddle.matmul(attention_probs, v)
        out = paddle.transpose(out, perm=[0, 2, 1, 3])  # [b, s, h, d]
        # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
        return out
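

# A minimal usage sketch (hypothetical sizes; the fused path also requires a
# supported GPU):
#     attn = DotProductAttention(num_attention_heads=16, kv_channels=64)
#     out = attn(q, k, v)  # q, k, v: [b, s, 16, 64] -> out: [b, s, 16, 64]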


class MultiHeadAttention(paddle.nn.Layer):
    """
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    hidden_size: int
                    hidden size of the model.
    num_attention_heads: int
                    number of attention heads.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon: float, default = 1e-5
                          epsilon to use in the layer norm operations.
    weight_attr: Union[paddle.ParamAttr, None], default = `None`
                    paddle.ParamAttr object for the weight parameter.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = `None`
                    paddle.ParamAttr object for the bias parameter.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    params_dtype: Optional[paddle.dtype], default = `None`
                    data type for the weights and biases.
    return_layernorm_output: bool, default = `False`
                    whether to return the output of the layernorm operation.
    input_layernorm: bool, default = `False`
                    whether to apply layernorm to the input.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    zero_centered_gamma: bool, default = `False`
                    whether to zero initialize the gamma of the layernorm operation.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation. If set to 'paddle', a
             framework-only, non-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    num_gqa_groups : int, default = `None`
                     number of GQA groups in the transformer layer.
                     Grouped Query Attention is described in
                     `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                     This only affects the keys and values, not the queries.
                     GQA-1 is equivalent to Multi-Query Attention
                     (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                     is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    rng_state_name : str, default = `local_seed`
                   Controls the RNG state used for dropout on attention probs. The
                   named RNG state should be seeded differently on different TP
                   ranks. It is ignored if `set_parallel_mode` is False. The
                   specified name should be registered through
                   `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
                   .add(rng_state_name, seed)`.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.

    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        max_sequence_length: Optional[int] = None,
        attn_mask_type: str = "causal",
        params_dtype: Optional[paddle.dtype] = None,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
        set_parallel_mode: bool = False,
        sequence_parallel: bool = False,
        tp_group: Optional[dist_group_type] = None,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        rng_state_name: str = "local_seed",
        backend: str = "transformer_engine",
    ) -> None:
        super().__init__()
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
        self.max_sequence_length = max_sequence_length
        self.weight_attr = weight_attr
        self.bias_attr = bias_attr
        self.attn_mask_type = attn_mask_type

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.tp_group, self.tp_size = get_tp_group_and_world_size(
            tp_group, enable_tp=set_parallel_mode
        )
        self.tensor_parallel = self.tp_size > 1
        self.sequence_parallel = self.tensor_parallel and sequence_parallel
        self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads
        self.set_parallel_mode = set_parallel_mode
        self.rng_state_name = rng_state_name
        self.backend = backend

        self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            self.num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % self.tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads)
        qkv_parallel_mode = "column" if set_parallel_mode else None

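        # Self-attention uses a single fused projection that produces h query
        # heads plus 2 * ng key/value heads: hidden_size + 2 * hidden_size_kv.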
        if self.attention_type == "self":
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    normalization=normalization,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )

        else:  # cross attention
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    normalization=normalization,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                self.weight_attr,
                self.bias_attr,
                parallel_mode=qkv_parallel_mode,
                sequence_parallel=self.sequence_parallel,
                tp_group=self.tp_group,
                fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                backend=self.backend,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.num_gqa_groups,
            attention_dropout,
            attn_mask_type=attn_mask_type,
            attention_type=self.attention_type,
            tp_size=self.tp_size,
            backend=self.backend,
        )

        # Linear
        self.proj = Linear(
            hidden_size,
            hidden_size,
            self.weight_attr,
            self.bias_attr,
            parallel_mode="row" if set_parallel_mode else None,
            sequence_parallel=self.sequence_parallel,
            tp_group=self.tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            backend=self.backend,
        )

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_output: Optional[paddle.Tensor] = None,
        rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
        recompute_core_attention: bool = False,
        is_first_microbatch: Optional[bool] = None,
    ) -> Tuple[Union[paddle.Tensor, None], ...]:
        """
        MultiHeadAttention Layer.

        Parameters
        ----------
        hidden_states : paddle.Tensor
                        Input tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                        Boolean tensor used to mask out the softmax input.
        encoder_output : Optional[paddle.Tensor], default = `None`
                        Output of the encoder layer.
        rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no rotary position embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                                only the `no_bias` type is currently supported
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                    Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
        recompute_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
        """

        if self.attn_mask_type != "causal" and attention_mask is not None:
            assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor"

        input_dim = len(hidden_states.shape)
        if input_dim == 2:
            # hidden_states: [b * s_q, hidden_size]
            # the sequence length cannot be inferred from the input, so it must
            # be provided via self.max_sequence_length
            assert self.max_sequence_length is not None, "max_sequence_length must be provided"
            max_seq_len = self.max_sequence_length
        elif input_dim == 3:
            # hidden_states: [b, s_q, hidden_size]
            max_seq_len = hidden_states.shape[1]
        else:
            raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")

        if self.attention_type == "self":
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_qkv_layer = layernorm_qkv_outputs
            else:
                mixed_qkv_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )

            # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d]
            mixed_qkv_layer = mixed_qkv_layer.reshape(
                shape=[
                    -1,
                    max_seq_len,
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                ]
            )
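            # Worked example (assumed sizes): h=16, ng=4, d=64 gives h/ng = 4
            # queries per KV group, so axis 2 holds 4 + 1 + 1 = 6 slots that the
            # split below carves into query, key and value.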

            # [b, s_q, (h/ng+2), ng, d]
            # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d]
            query_layer, key_layer, value_layer = paddle.split(
                mixed_qkv_layer,
                num_or_sections=(num_queries_per_key_value, 1, 1),
                axis=2,
            )

            # query: -> [b, s, h, d]
            # key, value: -> [b, s, ng, d]
            query_layer, key_layer, value_layer = (
                x.reshape(shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
                for x in (query_layer, key_layer, value_layer)
            )

        else:  # cross attention
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
            )
            # [b, s_kv, 2 * hidden_size_kv] --> [b, s_kv, 2 * ng, head_size]
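            # (a 0 in Paddle's reshape keeps the corresponding input dimension)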
            mixed_kv_layer = mixed_kv_layer.reshape(
                shape=[
                    0,
                    0,
                    2 * self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                ]
            )

            # [b, s_kv, 2 * ng, head_size]
            # --> 2 [b, s_kv, ng, head_size]
            key_layer, value_layer = paddle.split(
                mixed_kv_layer,
                num_or_sections=2,
                axis=2,
            )

            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [b, s, hidden_size] --> [b, s, h, d]
            query_layer = query_layer.reshape(
                shape=[
                    -1,
                    max_seq_len,
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                ]
            )

        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
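            # RotaryPositionEmbedding.forward returns (cos, sin): q_pos_emb is
            # the cosine table and k_pos_emb the sine table; both are applied to
            # Q and K below.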
            if fused_rotary_position_embedding is None:
                query_layer, key_layer = apply_rotary_pos_emb(
                    query_layer, key_layer, q_pos_emb, k_pos_emb
                )
            else:
                query_layer, key_layer, _ = fused_rotary_position_embedding(
                    query_layer,
                    key_layer,
                    v=None,
                    sin=k_pos_emb,
                    cos=q_pos_emb,
                    position_ids=None,
                    use_neox_rotary_style=False,
                )

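        # Run attention (and its dropout) under the named RNG tracker so dropout
        # masks differ across tensor-parallel ranks (see rng_state_name).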
        with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
            if recompute_core_attention:
                context_layer = recompute(
                    self.core_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask,
                    core_attention_bias_type,
                    core_attention_bias,
                    set_zero,
                    use_reentrant=False,
                )
            else:
                context_layer = self.core_attention(
                    query_layer=query_layer,
                    key_layer=key_layer,
                    value_layer=value_layer,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    set_zero=set_zero,
                )

        if input_dim == 3:
            context_layer = paddle.reshape(
                context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]]
            )
        else:  # input_dim == 2
            context_layer = paddle.reshape(
                context_layer, [-1, context_layer.shape[2] * context_layer.shape[3]]
            )

        # Output. [b, s, hidden]
        attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch)

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, layernorm_output
        return attention_output