# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Attention API"""

import math
import os
import warnings
from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None
from transformer_engine import transformer_engine_paddle as tex

from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
from ..constants import (
    AttnTypes,
    TE_DType,
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
    dist_group_type,
)
from ..cpp_extensions import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
    fused_attn_fwd,
    fused_attn_bwd,
    mask_to_cu_seqlens,
)
from ..distributed import get_tp_group_and_world_size, track_rng_state
from ..utils import attention_mask_func, divide
from ..recompute import recompute

__all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"]


def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    """
    Used to repeat the key and value states for GQA.
    The hidden states go from (batch, seqlen, num_gqa_groups, head_size)
    to (batch, seqlen, num_heads, head_size)
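
    Illustrative example: with n_rep=4, a [2, 128, 2, 64] key tensor
    becomes [2, 128, 8, 64].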
    """
    batch, seqlen, num_gqa_groups, head_size = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
    return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size])


class RotaryPositionEmbedding(paddle.nn.Layer):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
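
    A minimal usage sketch (illustrative values):

    .. code-block:: python

        rope = RotaryPositionEmbedding(dim=64, max_position_embeddings=2048)
        cos, sin = rope(max_seq_len=128)  # cos/sin tables laid out as [1, seq, 1, dim]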
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        max_position_embeddings: int
            max_position_embeddings before position interpolation
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = 1.0 / (
            10000 ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / self.dim)
        )
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype="float32")
        # [seq_len, dim/2]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        # [1, seqlen, 1, dim]
        self.cos_cached = emb.cos()[None, :, None, :]
        self.sin_cached = emb.sin()[None, :, None, :]

    def forward(self, max_seq_len: int):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        """
        cos = self.cos_cached[:, :max_seq_len, ...]
        sin = self.sin_cached[:, :max_seq_len, ...]
        return (cos, sin)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
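    # Illustrative: for x = [x0, x1, x2, x3] along the last dim, the result is [-x2, -x3, x0, x1].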
    return paddle.concat([-x2, x1], axis=-1)  # shape is the same as x


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Applies rotary positional embedding to the input."""

    if position_ids is None:
        # Note: Only for LlamaForCausalLMPipe model pretraining
        cos = cos[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
        sin = sin[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
    else:
        cos = cos.squeeze(axis=[0, 2])  # [seq_len, dim]
        sin = sin.squeeze(axis=[0, 2])  # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
        sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        attn_bias,
        max_seqlen,
        attn_scale,
        qkv_dtype,
        dropout_p,
        set_zero,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        is_training,
        deterministic,
        fused_attention_backend,
    ):
        """Forward function for FusedAttention with packed QKV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
            qkv,
            cu_seqlens,
            is_training,
            max_seqlen,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux)
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed QKV input"""
        qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor()
        dqkv, *rest = fused_attn_bwd_qkvpacked(
            qkv,
            cu_seqlens,
            rng_state,
            out,
            d_out,
            softmax_aux,
            ctx.fused_attention_backend,
            ctx.max_seqlen,
            ctx.qkv_dtype,
            ctx.attn_scale,
            ctx.dropout_p,
            ctx.set_zero,
            ctx.qkv_layout,
            ctx.attn_bias_type,
            ctx.attn_mask_type,
            ctx.deterministic,
        )

        # if no_bias, return dqkv
        if ctx.attn_bias_type == "no_bias":
            return (dqkv, None)
        # else, return (dqkv, dbias)
        return (dqkv, None, rest[0])


class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed KV input"""

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_kv,
        attn_bias,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        qkv_dtype,
        dropout_p,
        set_zero,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        is_training,
        deterministic,
        fused_attention_backend,
    ):
        """Forward function for FusedAttention with packed KV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
            q,
            kv,
            cu_seqlens_q,
            cu_seqlens_kv,
            is_training,
            max_seqlen_q,
            max_seqlen_kv,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed KV input"""
        q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dkv, *rest = fused_attn_bwd_kvpacked(
            q,
            kv,
            cu_seqlens_q,
            cu_seqlens_kv,
            rng_state,
            out,
            d_out,
            softmax_aux,
            ctx.fused_attention_backend,
            ctx.max_seqlen_q,
            ctx.max_seqlen_kv,
            ctx.qkv_dtype,
            ctx.attn_scale,
            ctx.dropout_p,
            ctx.set_zero,
            ctx.qkv_layout,
            ctx.attn_bias_type,
            ctx.attn_mask_type,
            ctx.deterministic,
        )

        # if no_bias, return dq, dkv
        if ctx.attn_bias_type == "no_bias":
            return (dq, dkv, None, None)
        # else, return (dq, dkv, dbias)
        return (dq, dkv, None, None, rest[0])


class FusedAttnFunc(paddle.autograd.PyLayer):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        attn_bias,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        qkv_dtype,
        dropout_p,
        set_zero,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        is_training,
        deterministic,
        fused_attention_backend,
    ):
        """Forward function for FusedAttention with separate Q, K, V tensors"""
        out, softmax_aux, rng_state = fused_attn_fwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            is_training,
            max_seqlen_q,
            max_seqlen_kv,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with separate Q, K, V tensors"""
        q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dk, dv, *rest = fused_attn_bwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            rng_state,
            out,
            d_out,
            softmax_aux,
            ctx.fused_attention_backend,
            ctx.max_seqlen_q,
            ctx.max_seqlen_kv,
            ctx.qkv_dtype,
            ctx.attn_scale,
            ctx.dropout_p,
            ctx.set_zero,
            ctx.qkv_layout,
            ctx.attn_bias_type,
            ctx.attn_mask_type,
            ctx.deterministic,
        )
        # if no_bias, return dq, dk, dv
        if ctx.attn_bias_type == "no_bias":
            return (dq, dk, dv, None, None)
        # else, return (dq, dk, dv, dbias)
        return (dq, dk, dv, None, None, rest[0])


class DotProductAttention(paddle.nn.Layer):
    """
    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`attn_mask_type` is set to `"causal"`.

    .. warning::

        Fused attention backward uses a non-deterministic algorithm when workspace
        optimization is not enabled. To use a deterministic algorithm, set the
        environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`

    Parameters
    ----------
    num_attention_heads: int
            number of attention heads in the transformer layer.
    kv_channels: int
            number of channels in the key and value tensors.
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    tp_size : int, default = 1
             tensor parallel world size.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation.
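
    A minimal usage sketch (illustrative values; q, k, v follow the default
    `bshd` layout, i.e. [batch, seqlen, heads, head_dim]):

    .. code-block:: python

        attn = DotProductAttention(num_attention_heads=16, kv_channels=64)
        out = attn(q, k, v)  # [batch, seqlen, heads, head_dim]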
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.1,
        attn_mask_type: str = "causal",
        attention_type: str = "self",
        tp_size: int = 1,
        backend: str = "transformer_engine",
    ) -> None:
        super().__init__()

        self.attn_mask_type = attn_mask_type
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.qkv_layout = "bshd_bshd_bshd"
        self.hidden_size_per_attention_head = kv_channels
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.tp_size = tp_size
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups
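        # e.g. 16 attention heads with 4 GQA groups -> each KV head is shared by 4 query heads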

        self.backend = backend

        self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))

        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

        # To use the workspace optimization path for determinism, please
        # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0,
        # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0.
        cudnn_version = paddle.get_cudnn_version()
        if 8905 <= cudnn_version < 9000:
            if self.deterministic:
                # workspace optimization path is deterministic
                os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

            # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
            # - unset:       enables workspace optimization when required workspace is <= 256MB
            #                or when bias gradient needs to be computed
            # - n:           enables workspace optimization when required workspace is <= n bytes
            # - -1:          enables workspace optimization always
            # - 0:           disables workspace optimization always
            if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ:
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0"
                if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
                    os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

        if not self.use_fused_attention and backend == "transformer_engine":
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            self.backend = "paddle"

        if self.backend != "transformer_engine":
            self.scale_mask_softmax = FusedScaleMaskSoftmax(
                attn_mask_type, attention_mask_func, backend=self.backend
            )

    def forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
            is set to `"causal"`.


        Parameters
        ----------
        query_layer : paddle.Tensor
                      Query tensor.
        key_layer : paddle.Tensor
                      Key tensor.
        value_layer : paddle.Tensor
                      Value tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                         Boolean tensor used to mask out softmax input when not using attention.
        core_attention_bias_type: str, default = `no_bias`
                                  only the `no_bias` type is currently supported, {`no_bias`}
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                             Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                  Whether to use the fast path to set output tensors to 0 or not.
        """

        backend = self.backend

        assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
        assert (
            key_layer.shape[-2] == self.num_gqa_groups_per_partition
        ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"

        if backend == "transformer_engine":
            max_s_q = query_layer.shape[1]
            max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1]
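            # Ask the TE extension which cuDNN fused-attention backend (if any) supports
            # this dtype / layout / bias / mask / head-dim / sequence-length combination.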
            self.fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype],
                TE_DType[query_layer.dtype],
                tex.get_nvte_qkv_layout(self.qkv_layout),
                AttnBiasType[core_attention_bias_type],
                AttnMaskType[self.attn_mask_type],
                self.attention_dropout,
                query_layer.shape[-2],
                key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2],
                max_s_q,
                max_s_kv,
                query_layer.shape[-1],
            )

            is_backend_avail = self.fused_attention_backend in [
                FusedAttnBackend["F16_max512_seqlen"],
                FusedAttnBackend["F16_arbitrary_seqlen"],
            ]
            if is_backend_avail and self.use_fused_attention:
                return self._te_forward(
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask,
                    core_attention_bias_type,
                    core_attention_bias,
                    set_zero,
                )
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            backend = "paddle"
            self.scale_mask_softmax = FusedScaleMaskSoftmax(
                self.attn_mask_type, attention_mask_func, backend=backend
            )
        if backend == "paddle":
            if core_attention_bias_type != "no_bias":
                warnings.warn(
                    "Paddle backend dot product attention does not support bias yet. "
                    "Bias will be ignored."
                )
            return self._pd_forward(query_layer, key_layer, value_layer, attention_mask)
        raise AttributeError(f"Backend {backend} is not supported.")

    def _te_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:

        if self.attention_type == "self":
            # self attention - q, k, v: [b, s, h, d]
            assert (
                len(query_layer.shape) == 4
                and len(key_layer.shape) == 4
                and len(value_layer.shape) == 4
            ), "q,k,v shape must be [b, s, h, d] for dot product self attention"
            max_seqlen = query_layer.shape[1]
            if self.attn_mask_type == "causal" or attention_mask is None:
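                # Dense (unpadded) sequences: cumulative sequence lengths are simply
                # multiples of the sequence length, e.g. b=2, s=128 -> [0, 128, 256].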
                cu_seqlens = paddle.arange(
                    0,
                    (query_layer.shape[0] + 1) * query_layer.shape[1],
                    step=query_layer.shape[1],
                    dtype="int32",
                )
            else:
                cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
            qkv_dtype = TE_DType[query_layer.dtype]

            output = FusedAttnFunc.apply(
                query_layer,
                key_layer,
                value_layer,
                cu_seqlens,
                cu_seqlens,
                core_attention_bias,
                max_seqlen,
                max_seqlen,
                1.0 / self.norm_factor,
                qkv_dtype,
                self.attention_dropout if self.training else 0.0,
                set_zero,
                self.qkv_layout,
                core_attention_bias_type,
                self.attn_mask_type,
                self.training,
                self.deterministic,
                self.fused_attention_backend,
            )
        elif self.attention_type == "cross":
            # cross attention - q: [b, s_q, h, d]  k,v: [b, s_kv, h, d]
            assert (
                len(query_layer.shape) == 4
                and len(key_layer.shape) == 4
                and len(value_layer.shape) == 4
            ), (
                "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d] "
                "for dot product cross attention"
            )
            assert attention_mask is not None, "attention_mask must be provided for cross attention"
            max_seqlen_q = query_layer.shape[1]
            max_seqlen_kv = key_layer.shape[1]
            cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
            qkv_dtype = TE_DType[query_layer.dtype]
            output = FusedAttnFunc.apply(
                query_layer,
                key_layer,
                value_layer,
                cu_seqlens_q,
                cu_seqlens_kv,
                core_attention_bias,
                max_seqlen_q,
                max_seqlen_kv,
                1.0 / self.norm_factor,
                qkv_dtype,
                self.attention_dropout if self.training else 0.0,
                set_zero,
                self.qkv_layout,
                core_attention_bias_type,
                self.attn_mask_type,
                self.training,
                self.deterministic,
                self.fused_attention_backend,
            )
        else:
            raise ValueError("attention_type must be one of ['self', 'cross']")
        return output

    def _pd_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
    ) -> paddle.Tensor:

        q = query_layer
        k = repeat_kv(key_layer, self.num_queries_per_key_value)
        v = repeat_kv(value_layer, self.num_queries_per_key_value)

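        # Unfused reference path: move heads to axis 1 and compute
        # softmax(Q K^T / sqrt(head_dim)) V with plain paddle ops.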
        q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
        k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
        v = paddle.transpose(x=v, perm=[0, 2, 1, 3])

        product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True)
        attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None)

        if self.attention_dropout > 0:
            attention_probs = F.dropout(
                attention_probs,
                self.attention_dropout,
                training=self.training,
            )

        out = paddle.matmul(attention_probs, v)
        out = paddle.transpose(out, perm=[0, 2, 1, 3])  # [b, s, h, d]
        # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
        return out


class MultiHeadAttention(paddle.nn.Layer):
    """
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    hidden_size: int
                    hidden size of the model.
    num_attention_heads: int
                    number of attention heads.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon: float, default = 1e-5
                          epsilon to use in the layer norm operations.
    weight_attr: Union[paddle.ParamAttr, None], default = `None`
                    paddle.ParamAttr object for the weight parameter.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = `None`
                    paddle.ParamAttr object for the bias parameter.
    max_sequence_length: Optional[int], default = `None`
                    maximum sequence length of the input; required when `hidden_states`
                    is 2-D ([b * s, hidden_size]), since the sequence length cannot be
                    inferred from the input shape.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    params_dtype: Optional[paddle.dtype], default = `None`
                    data type for the weights and biases.
    return_layernorm_output: bool, default = `False`
                    whether to return the output of the layernorm operation.
    input_layernorm: bool, default = `False`
                    whether to apply layernorm to the input.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    zero_centered_gamma: bool, default = `False`
                    whether to zero initialize the gamma of the layernorm operation.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation. If set to 'paddle', a framework-only,
             non-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    num_gqa_groups : int, default = `None`
                     number of GQA groups in the transformer layer.
                     Grouped Query Attention is described in
                     `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                     This only affects the keys and values, not the queries.
                     GQA-1 is equivalent to Multi-Query Attention
                     (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                     is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    rng_state_name : str, default = `local_seed`
                   Controls the rng state used for dropout on attention probs. The
                   specified rng should use different seeds on different TP ranks.
                   It will be ignored if `set_parallel_mode` is False. The specified
                   name should be registered through
                   `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
                   .add(rng_state_name, seed)`.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
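
    A minimal usage sketch (illustrative values):

    .. code-block:: python

        mha = MultiHeadAttention(hidden_size=1024, num_attention_heads=16)
        # hidden_states: [batch, seqlen, hidden_size]
        out = mha(hidden_states)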

    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        max_sequence_length: Optional[int] = None,
        attn_mask_type: str = "causal",
        params_dtype: Optional[paddle.dtype] = None,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
        set_parallel_mode: bool = False,
        sequence_parallel: bool = False,
        tp_group: Optional[dist_group_type] = None,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        rng_state_name: str = "local_seed",
        backend: str = "transformer_engine",
    ) -> None:
        super().__init__()
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
        self.max_sequence_length = max_sequence_length
        self.weight_attr = weight_attr
        self.bias_attr = bias_attr
        self.attn_mask_type = attn_mask_type

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.tp_group, self.tp_size = get_tp_group_and_world_size(
            tp_group, enable_tp=set_parallel_mode
        )
        self.tensor_parallel = self.tp_size > 1
        self.sequence_parallel = self.tensor_parallel and sequence_parallel
        self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads
        self.set_parallel_mode = set_parallel_mode
        self.rng_state_name = rng_state_name
        self.backend = backend

        self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        assert (
            self.num_attention_heads % self.num_gqa_groups == 0
        ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (
            self.num_gqa_groups % self.tp_size == 0
        ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads)
        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    normalization=normalization,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )

        else:  # cross attention
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    normalization=normalization,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                self.weight_attr,
                self.bias_attr,
                parallel_mode=qkv_parallel_mode,
                sequence_parallel=self.sequence_parallel,
                tp_group=self.tp_group,
                fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                backend=self.backend,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.num_gqa_groups,
            attention_dropout,
            attn_mask_type=attn_mask_type,
            attention_type=self.attention_type,
            tp_size=self.tp_size,
            backend=self.backend,
        )

        # Linear
        self.proj = Linear(
            hidden_size,
            hidden_size,
            self.weight_attr,
            self.bias_attr,
            parallel_mode="row" if set_parallel_mode else None,
            sequence_parallel=self.sequence_parallel,
            tp_group=self.tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            backend=self.backend,
        )

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_output: Optional[paddle.Tensor] = None,
        rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
        recompute_core_attention: bool = False,
        is_first_microbatch: Optional[bool] = None,
    ) -> Tuple[Union[paddle.Tensor, None], ...]:
        """
        MultiHeadAttention Layer.

        Parameters
        ----------
        hidden_states : paddle.Tensor
                        Input tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                        Boolean tensor used to mask out softmax input when not using attention.
        encoder_output : Optional[paddle.Tensor], default = `None`
                        Output of the encoder layer.
        rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                                only the `no_bias` type is currently supported, {`no_bias`}
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                    Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
        recompute_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
        """

        if self.attn_mask_type != "causal" and attention_mask is not None:
            assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor"

        input_dim = len(hidden_states.shape)
        if input_dim == 2:
            # hidden_states: [b * s_q, hidden_size]
            # sequence length cannot be inferred from a 2-D input; use self.max_sequence_length
            assert self.max_sequence_length is not None, "max_sequence_length must be provided"
            max_seq_len = self.max_sequence_length
        elif input_dim == 3:
            # hidden_states: [b, s_q, hidden_size]
            max_seq_len = hidden_states.shape[1]
        else:
            raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")

        if self.attention_type == "self":
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_qkv_layer = layernorm_qkv_outputs
            else:
                mixed_qkv_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            num_queries_per_key_value = (
                self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition
            )

            # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d]
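            # e.g. h=16, ng=4, d=64: [b, s_q, 1024 + 512] -> [b, s_q, 6, 4, 64]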
            mixed_qkv_layer = mixed_qkv_layer.reshape(
                shape=[
                    -1,
                    max_seq_len,
                    (num_queries_per_key_value + 2),
                    self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                ]
            )

            # [b, s_q, (h/ng+2), ng, d]
            # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d]
            query_layer, key_layer, value_layer = paddle.split(
                mixed_qkv_layer,
                num_or_sections=(num_queries_per_key_value, 1, 1),
                axis=2,
            )

            # query: -> [b, s, h, d]
            # key, value: -> [b, s, ng, d]
            query_layer, key_layer, value_layer = (
                x.reshape(shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
                for x in (query_layer, key_layer, value_layer)
            )

        else:  # cross attention
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
            )
            # [b, s_kv, 2 * hidden_size_kv] --> [b, s_kv, 2 * ng, head_size]
            mixed_kv_layer = mixed_kv_layer.reshape(
                shape=[
                    0,
                    0,
                    2 * self.num_gqa_groups_per_partition,
                    self.hidden_size_per_attention_head,
                ]
            )

            # [b, s_kv, 2 * ng, head_size]
            # --> 2 [b, s_kv, ng, head_size]
            key_layer, value_layer = paddle.split(
                mixed_kv_layer,
                num_or_sections=2,
                axis=2,
            )

            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [b, s, hidden_size] --> [b, s, h, d]
            query_layer = query_layer.reshape(
                shape=[
                    -1,
                    max_seq_len,
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                ]
            )

        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
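            # rotary_pos_emb is typically the (cos, sin) pair returned by
            # RotaryPositionEmbedding; q_pos_emb holds the cos table, k_pos_emb the sin table.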
            if fused_rotary_position_embedding is None:
                query_layer, key_layer = apply_rotary_pos_emb(
                    query_layer, key_layer, q_pos_emb, k_pos_emb
                )
            else:
                query_layer, key_layer, _ = fused_rotary_position_embedding(
                    query_layer,
                    key_layer,
                    v=None,
                    sin=k_pos_emb,
                    cos=q_pos_emb,
                    position_ids=None,
                    use_neox_rotary_style=False,
                )

        with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
            if recompute_core_attention:
                context_layer = recompute(
                    self.core_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask,
                    core_attention_bias_type,
                    core_attention_bias,
                    set_zero,
                    use_reentrant=False,
                )
            else:
                context_layer = self.core_attention(
                    query_layer=query_layer,
                    key_layer=key_layer,
                    value_layer=value_layer,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    set_zero=set_zero,
                )

        if input_dim == 3:
            context_layer = paddle.reshape(
                context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]]
            )
        else:  # input_dim == 2
            context_layer = paddle.reshape(
                context_layer, [-1, context_layer.shape[2] * context_layer.shape[3]]
            )

        # Output. [b, s, hidden]
        attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch)

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, layernorm_output
        return attention_output