attention.py 29.3 KB
Newer Older
Shijie's avatar
Shijie committed
1
2
3
4
5
6
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Attntion API"""

import math
7
import os
Shijie's avatar
Shijie committed
8
9
10
11
12
import warnings
from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F
13
import transformer_engine_paddle as tex
Shijie's avatar
Shijie committed
14

15
16
17
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
18
19
from ..constants import (AttnTypes, TE_DType, AttnBiasType, AttnMaskType, FusedAttnBackend,
                         dist_group_type)
20
from ..cpp_extensions import (
Shijie's avatar
Shijie committed
21
22
23
24
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
25
    mask_to_cu_seqlens,
Shijie's avatar
Shijie committed
26
)
27
from ..distributed import get_tp_group_and_world_size, track_rng_state
28
from ..utils import attention_mask_func, divide
Tian Zheng's avatar
Tian Zheng committed
29
from ..recompute import recompute
Shijie's avatar
Shijie committed
30

31
32
33
__all__ = ["DotProductAttention", "MultiHeadAttention"]


Shijie's avatar
Shijie committed
34
35
36
37
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
    """Autograd function running fused attention on a single packed QKV tensor."""

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, attn_bias, max_seqlen, attn_scale, qkv_dtype, dropout_p,
                set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training,
                fused_attention_backend):
        """Run the fused forward kernel on packed QKV and stash state for backward.

        Returns the attention output produced by ``fused_attn_fwd_qkvpacked``.
        """
        out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
            qkv, cu_seqlens, is_training, max_seqlen, qkv_dtype, fused_attention_backend,
            attn_bias, attn_scale, dropout_p, set_zero, qkv_layout, attn_bias_type,
            attn_mask_type)

        # Tensors needed by the backward kernel.
        ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux)

        # Non-tensor configuration, replayed unchanged into the backward kernel.
        for attr_name, attr_value in (
            ("max_seqlen", max_seqlen),
            ("qkv_dtype", qkv_dtype),
            ("attn_scale", attn_scale),
            ("dropout_p", dropout_p),
            ("set_zero", set_zero),
            ("qkv_layout", qkv_layout),
            ("attn_bias_type", attn_bias_type),
            ("attn_mask_type", attn_mask_type),
            ("fused_attention_backend", fused_attention_backend),
        ):
            setattr(ctx, attr_name, attr_value)

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Run the fused backward kernel and return gradients for the tensor inputs.

        Returns ``(dqkv, None)`` when ``attn_bias_type == "no_bias"``, otherwise
        ``(dqkv, None, dbias)``. No gradient is produced for ``cu_seqlens``.
        """
        qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor()
        dqkv, *bias_grads = fused_attn_bwd_qkvpacked(
            qkv, cu_seqlens, rng_state, out, d_out, softmax_aux, ctx.fused_attention_backend,
            ctx.max_seqlen, ctx.qkv_dtype, ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
            ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)

        # NOTE(review): with "no_bias" the bias input is presumably None and therefore not
        # counted as a tensor input, so only two gradients are returned.
        if ctx.attn_bias_type != "no_bias":
            return (dqkv, None, bias_grads[0])
        return (dqkv, None)
Shijie's avatar
Shijie committed
86
87
88
89
90
91


class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
    """Autograd function running fused attention on a query tensor plus a packed KV tensor."""

    @staticmethod
    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
                attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
                attn_mask_type, is_training, fused_attention_backend):
        """Run the fused forward kernel on (Q, packed KV) and stash state for backward.

        Returns the attention output produced by ``fused_attn_fwd_kvpacked``.
        """
        out, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
            q,
            kv,
            cu_seqlens_q,
            cu_seqlens_kv,
            is_training,
            max_seqlen_q,
            max_seqlen_kv,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        # Tensors needed by the backward kernel.
        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)

        # Non-tensor configuration, replayed unchanged into the backward kernel.
        for attr_name, attr_value in (
            ("max_seqlen_q", max_seqlen_q),
            ("max_seqlen_kv", max_seqlen_kv),
            ("qkv_dtype", qkv_dtype),
            ("attn_scale", attn_scale),
            ("dropout_p", dropout_p),
            ("set_zero", set_zero),
            ("qkv_layout", qkv_layout),
            ("attn_bias_type", attn_bias_type),
            ("attn_mask_type", attn_mask_type),
            ("fused_attention_backend", fused_attention_backend),
        ):
            setattr(ctx, attr_name, attr_value)

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Run the fused backward kernel and return gradients for the tensor inputs.

        Returns ``(dq, dkv, None, None)`` when ``attn_bias_type == "no_bias"``, otherwise
        ``(dq, dkv, None, None, dbias)``. No gradients are produced for the two
        ``cu_seqlens`` tensors.
        """
        saved = ctx.saved_tensor()
        q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = saved
        dq, dkv, *bias_grads = fused_attn_bwd_kvpacked(
            q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out, d_out, softmax_aux,
            ctx.fused_attention_backend, ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
            ctx.attn_scale, ctx.dropout_p, ctx.set_zero, ctx.qkv_layout, ctx.attn_bias_type,
            ctx.attn_mask_type)

        grads = (dq, dkv, None, None)
        # NOTE(review): with "no_bias" the bias input is presumably None and therefore not
        # counted as a tensor input, so no dbias slot is returned.
        if ctx.attn_bias_type != "no_bias":
            grads += (bias_grads[0],)
        return grads
Shijie's avatar
Shijie committed
131
132
133


class DotProductAttention(paddle.nn.Layer):
    """
    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`attn_mask_type` is set to `"causal"`.

    .. note::

        Setting the environment variable ``NVTE_FUSED_ATTN=0`` disables the fused
        attention kernels, and a ``'transformer_engine'`` backend then falls back to
        the ``'paddle'`` backend. The fused path is also skipped at call time when no
        fused kernel supports the given inputs.

    Parameters
    ----------
    norm_factor : float
                    normalization factor for the attention scores.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation.
    """

    def __init__(self,
                 norm_factor: float,
                 attention_dropout: float = 0.1,
                 attn_mask_type: str = "causal",
                 attention_type: str = "self",
                 backend: str = 'transformer_engine') -> None:
        super().__init__()

        self.norm_factor = norm_factor
        self.attn_mask_type = attn_mask_type
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        # Packed layout handed to the fused kernels: packed QKV for self attention,
        # separate Q plus packed KV for cross attention.
        self.qkv_layout = "bs3hd" if attention_type == "self" else "bshd_bs2hd"

        self.backend = backend

        # NVTE_FUSED_ATTN=0 turns the fused kernels off entirely.
        self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))

        if not self.use_fused_attention and backend == 'transformer_engine':
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            self.backend = 'paddle'

        if self.backend != 'transformer_engine':
            self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type,
                                                            attention_mask_func,
                                                            backend=self.backend)

    def forward(
        self,
        query_layer: paddle.Tensor,
        key_value_layer: paddle.Tensor = None,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
            is set to `"causal"`.

        .. note::

            For self attention, :attr:`query_layer` is the `[query, key, value]` tensor
            stacked along the 2nd dimension, which must be of shape (:attr:`batch_size`,
            :attr:`seq_length`, 3, :attr:`num_attention_heads`, :attr:`size_per_head`).
            And :attr:`key_value_layer` is `None`.
            For cross attention, :attr:`query_layer` is the `[query]` tensor, which must
            be of shape (:attr:`batch_size`, :attr:`seq_length`, :attr:`num_attention_heads`,
            :attr:`size_per_head`). And :attr:`key_value_layer` is the `[key, value]` tensor,
            which must be of shape (:attr:`batch_size`, :attr:`seq_length`, 2,
            :attr:`num_attention_heads`, :attr:`size_per_head`).

        Parameters
        ----------
        query_layer : paddle.Tensor
                      Query tensor (packed QKV tensor for self attention).
        key_value_layer : paddle.Tensor
                          Packed key/value tensor for cross attention; must be `None`
                          for self attention.
        attention_mask : Optional[paddle.Tensor], default = `None`
                         Boolean tensor used to mask out softmax input. Ignored when
                         :attr:`attn_mask_type` is `"causal"`; required for cross
                         attention on the fused path.
        core_attention_bias_type: str, default = `no_bias`
                                  only support no_bias type currently, {`no_bias`}
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                             Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                  Whether to use the fast path to set output tensors to 0 or not.

        Returns
        -------
        paddle.Tensor
            Context tensor produced by the selected backend.

        Raises
        ------
        AttributeError
            If the configured backend is neither 'transformer_engine' nor 'paddle'.
        """
        backend = self.backend

        if backend == 'transformer_engine':
            if (self._fused_attention_supported(query_layer, key_value_layer,
                                                core_attention_bias_type)
                    and self.use_fused_attention):
                return self._te_forward(query_layer, key_value_layer, attention_mask,
                                        core_attention_bias_type, core_attention_bias, set_zero)
            warnings.warn("Fused attention backend is not available for the given inputs, "
                          "falling back to Paddle backend")
            backend = 'paddle'
            # Build the fallback softmax once and reuse it. It used to be rebuilt
            # (and re-registered as a sublayer) on every call.
            if getattr(self, "scale_mask_softmax", None) is None:
                self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type,
                                                                attention_mask_func,
                                                                backend=backend)

        if backend == 'paddle':
            if core_attention_bias_type != "no_bias":
                warnings.warn("Paddle backend dot product attention does not support bias yet. "
                              "Bias will be ignored.")
            return self._pd_forward(query_layer, key_value_layer, attention_mask)
        raise AttributeError(f"Backend {backend} is not supported.")

    def _fused_attention_supported(
        self,
        query_layer: paddle.Tensor,
        key_value_layer: Optional[paddle.Tensor],
        core_attention_bias_type: str,
    ) -> bool:
        """Pick a fused-attention kernel for these inputs and report whether it is usable.

        Stores the selection on ``self.fused_attention_backend``, which ``_te_forward``
        reads. Returns True only for the F16 max-512 / arbitrary-seqlen kernels.
        """
        max_s_q = query_layer.shape[1]
        max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1]
        self.fused_attention_backend = tex.get_fused_attn_backend(
            TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
            tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
            AttnMaskType[self.attn_mask_type], self.attention_dropout, max_s_q, max_s_kv,
            query_layer.shape[-1])
        return self.fused_attention_backend in (
            FusedAttnBackend["F16_max512_seqlen"],
            FusedAttnBackend["F16_arbitrary_seqlen"],
        )

    def _te_forward(
        self,
        query_layer: paddle.Tensor,
        key_value_layer: paddle.Tensor = None,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:
        """Attention through the fused Transformer Engine kernels.

        Requires ``self.fused_attention_backend`` to have been set by
        ``_fused_attention_supported``.

        Raises
        ------
        ValueError
            If ``attention_type`` is neither 'self' nor 'cross'.
        """
        if self.attention_type == "self":
            # self attention - q: [b, s, 3, h, d]  kv: None
            assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
                    and key_value_layer is None
                   ), "query shape must be [b, s, 3, h, d] for dot product self attention"
            max_seqlen = query_layer.shape[1]
            if self.attn_mask_type == "causal" or attention_mask is None:
                # Every sequence uses the full length: cu_seqlens = [0, s, 2s, ..., b*s].
                cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1],
                                           step=query_layer.shape[1],
                                           dtype='int32')
            else:
                cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
            qkv_dtype = TE_DType[query_layer.dtype]

            output = FusedAttnFuncPackedQKV.apply(query_layer, cu_seqlens, core_attention_bias,
                                                  max_seqlen, 1.0 / self.norm_factor, qkv_dtype,
                                                  self.attention_dropout if self.training else 0.0,
                                                  set_zero, self.qkv_layout,
                                                  core_attention_bias_type, self.attn_mask_type,
                                                  self.training, self.fused_attention_backend)
        elif self.attention_type == "cross":
            # cross attention - q: [b, s_q, h, d]  kv: [b, s_kv, 2, h, d]
            assert (
                len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5
                and key_value_layer.shape[2] == 2
            ), "query shape must be [b, s, h, d] and key shape must be [b, s, 2, h, d] " \
                "for dot product cross attention"
            assert (attention_mask
                    is not None), "attention_mask must be provided for cross attention"
            max_seqlen_q = query_layer.shape[1]
            max_seqlen_kv = key_value_layer.shape[1]
            cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
            qkv_dtype = TE_DType[query_layer.dtype]
            output = FusedAttnFuncPackedKV.apply(query_layer, key_value_layer, cu_seqlens_q,
                                                 cu_seqlens_kv, core_attention_bias, max_seqlen_q,
                                                 max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
                                                 self.attention_dropout if self.training else 0.0,
                                                 set_zero, self.qkv_layout,
                                                 core_attention_bias_type, self.attn_mask_type,
                                                 self.training, self.fused_attention_backend)
        else:
            raise ValueError("attention_type must be one of ['self', 'cross']")
        return output

    def _pd_forward(
        self,
        query_layer: paddle.Tensor,
        key_value_layer: paddle.Tensor = None,
        attention_mask: Optional[paddle.Tensor] = None,
    ) -> paddle.Tensor:
        """Reference attention built from plain Paddle ops.

        Returns the context tensor with layout [b, s_q, h, d].

        Raises
        ------
        ValueError
            If ``attention_type`` is neither 'self' nor 'cross'.
        """
        if self.attention_type == "self":
            # self attention - q: [b, s, 3, h, d]  k: None
            assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
                    and key_value_layer is None
                   ), "query shape must be [b, s, 3, h, d] for dot product self attention"
            q = query_layer[:, :, 0]
            k = query_layer[:, :, 1]
            v = query_layer[:, :, 2]
        elif self.attention_type == "cross":
            # cross attention - q: [b, s, h, d]  kv: [b, s, 2, h, d]
            assert (
                len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5
                and key_value_layer.shape[2] == 2
            ), f"query shape must be [b, s, h, d] and key_value shape must be [b, s, 2, h, d] " \
               f"for dot product cross attention. The actual shape is q: {query_layer.shape}, " \
               f"kv: {key_value_layer.shape}"
            q = query_layer
            k = key_value_layer[:, :, 0]
            v = key_value_layer[:, :, 1]
        else:
            # Previously fell through and failed with an UnboundLocalError on `q`.
            raise ValueError("attention_type must be one of ['self', 'cross']")

        # [b, s, h, d] -> [b, h, s, d]
        q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
        k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
        v = paddle.transpose(x=v, perm=[0, 2, 1, 3])

        # Q is pre-scaled here, so the softmax is called with scale=None.
        product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True)
        attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None)

        if self.attention_dropout > 0:
            attention_probs = F.dropout(
                attention_probs,
                self.attention_dropout,
                training=self.training,
            )

        out = paddle.matmul(attention_probs, v)
        out = paddle.transpose(out, perm=[0, 2, 1, 3])    # [b, s, h, d]
        return out


361
class MultiHeadAttention(paddle.nn.Layer):
    """
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    hidden_size: int
                    hidden size of the model. Must be divisible by `num_attention_heads`.
    num_attention_heads: int
                    number of attention heads.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon: float, default = 1e-5
                          epsilon to use in the layer norm operations.
    weight_attr: Union[paddle.ParamAttr, None], default = `None`
                    paddle.ParamAttr object for the weight parameter.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = `None`
                    paddle.ParamAttr object for the bias parameter.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    params_dtype: Optional[paddle.dtype], default = `None`
                    data type for the weights and biases.
    return_layernorm_output: bool, default = `False`
                    whether to return the output of the layernorm operation.
    input_layernorm: bool, default = `False`
                    whether to apply layernorm to the input.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    zero_centered_gamma: bool, default = `False`
                    whether to zero initialize the gamma of the layernorm operation.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation. If set to 'paddle', a framework
             only no-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    rng_state_name : str, default = `local_seed`
                   Controls the rng state used for dropout on attention probs. The
                   specified rng should be set different seeds for different TP ranks.
                   It will be ignored if `set_parallel_mode` is False. The specified
                   name should be registered through
                   `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
                   .add(rng_state_name, seed)`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        attn_mask_type: str = "causal",
        params_dtype: Optional[paddle.dtype] = None,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        zero_centered_gamma: bool = False,
        set_parallel_mode: bool = False,
        tp_group: Optional[dist_group_type] = None,
        rng_state_name: str = 'local_seed',
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
        self.weight_attr = weight_attr
        self.bias_attr = bias_attr
        self.attn_mask_type = attn_mask_type

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
        # Fail fast: a non-divisible hidden size would otherwise surface later as an
        # opaque reshape error in forward().
        assert hidden_size % num_attention_heads == 0, (
            f"hidden_size ({hidden_size}) must be divisible by "
            f"num_attention_heads ({num_attention_heads})")

        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                  enable_tp=set_parallel_mode)
        self.tensor_parallel = self.tp_size > 1

        self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads
        norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.set_parallel_mode = set_parallel_mode
        self.rng_state_name = rng_state_name
        self.backend = backend

        self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
        qkv_parallel_mode = "column" if set_parallel_mode else None

        # Sublayers are created in a fixed order (QKV projections, core attention, proj).
        if self.attention_type == "self":
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    3 * hidden_size,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    3 * hidden_size,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )

        else:    # cross attention
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            self.key_value = Linear(
                hidden_size,
                2 * hidden_size,
                self.weight_attr,
                self.bias_attr,
                parallel_mode=qkv_parallel_mode,
                tp_group=self.tp_group,
                backend=self.backend,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            norm_factor,
            attention_dropout,
            attn_mask_type=attn_mask_type,
            attention_type=self.attention_type,
            backend=self.backend,
        )

        # Linear
        self.proj = Linear(
            hidden_size,
            hidden_size,
            self.weight_attr,
            self.bias_attr,
            parallel_mode="row" if set_parallel_mode else None,
            tp_group=self.tp_group,
            backend=self.backend,
        )

    def _run_core_attention(
        self,
        query_layer: paddle.Tensor,
        key_value_layer: Optional[paddle.Tensor],
        attention_mask: Optional[paddle.Tensor],
        core_attention_bias_type: str,
        core_attention_bias: Optional[paddle.Tensor],
        set_zero: bool,
        recompute_core_attention: bool,
    ) -> paddle.Tensor:
        """Run core attention under the configured RNG state, optionally recomputed.

        The call is wrapped in ``track_rng_state`` (active only with tensor parallelism).
        When ``recompute_core_attention`` is set, it goes through ``recompute`` so
        activations are regenerated in backward.
        """
        with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
            if recompute_core_attention:
                return recompute(
                    self.core_attention,
                    query_layer,
                    key_value_layer,
                    attention_mask,
                    core_attention_bias_type,
                    core_attention_bias,
                    set_zero,
                    use_reentrant=False,
                )
            return self.core_attention(
                query_layer=query_layer,
                key_value_layer=key_value_layer,
                attention_mask=attention_mask,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                set_zero=set_zero,
            )

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_output: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
        recompute_core_attention: bool = False,
    ) -> Tuple[Union[paddle.Tensor, None], ...]:
        """
        MultiHeadAttention Layer.

        Parameters
        ----------
        hidden_states : paddle.Tensor
                        Input tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                        Boolean tensor used to mask out softmax input. Must have dtype
                        `paddle.bool` when :attr:`attn_mask_type` is not `"causal"`.
        encoder_output : Optional[paddle.Tensor], default = `None`
                        Output of the encoder layer; required for cross attention, where
                        it is projected into keys and values.
        core_attention_bias_type: str, default = `no_bias`
                                only support no_bias type currently, {`no_bias`}
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                    Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
        recompute_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.

        Returns
        -------
        The projected attention output. If both `input_layernorm` and
        `return_layernorm_output` are set, returns `(attention_output, layernorm_output)`.
        """

        # hidden_states: [b, s_q, hidden_size]
        if self.attn_mask_type != "causal" and attention_mask is not None:
            assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"

        # Only assigned (and only returned) when input_layernorm and return_layernorm_output.
        layernorm_output = None

        if self.attention_type == "self":
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(hidden_states)
                if self.return_layernorm_output:
                    mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_qkv_layer = layernorm_qkv_outputs
            else:
                mixed_qkv_layer = self.qkv(hidden_states)

            # [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size]
            mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
                0, 0, 3, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
            ])

            context_layer = self._run_core_attention(mixed_qkv_layer, None, attention_mask,
                                                     core_attention_bias_type,
                                                     core_attention_bias, set_zero,
                                                     recompute_core_attention)

        else:    # cross attention
            assert encoder_output is not None, \
                "encoder_output must be provided for cross attention"
            mixed_kv_layer = self.key_value(encoder_output)
            # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size]
            mixed_kv_layer = mixed_kv_layer.reshape(shape=[
                0, 0, 2, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
            ])

            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(hidden_states)
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(hidden_states)

            # [b, s_q, hidden_size] --> [b, s_q, num_heads, head_size]
            query_layer = query_layer.reshape(shape=[
                0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
            ])

            context_layer = self._run_core_attention(query_layer, mixed_kv_layer, attention_mask,
                                                     core_attention_bias_type,
                                                     core_attention_bias, set_zero,
                                                     recompute_core_attention)

        # [b, s_q, h, d] --> [b, s_q, h * d]
        context_layer = paddle.reshape(context_layer,
                                       [0, 0, context_layer.shape[2] * context_layer.shape[3]])
        # Output. [b, s, hidden]
        attention_output = self.proj(context_layer)

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, layernorm_output
        return attention_output