# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Attntion API"""

import math
import os
import warnings
from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F
try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None
import transformer_engine_paddle as tex

from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
from ..constants import (AttnTypes, TE_DType, AttnBiasType, AttnMaskType, FusedAttnBackend,
                         dist_group_type)
from ..cpp_extensions import (
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
    fused_attn_fwd,
    fused_attn_bwd,
    mask_to_cu_seqlens,
)
from ..distributed import get_tp_group_and_world_size, track_rng_state
from ..utils import attention_mask_func, divide
from ..recompute import recompute

__all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"]


def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    """
    Used to repeat the key and value states for GQA.
    The hidden states go from (batch, seqlen, num_gqa_groups, head_size)
    to (batch, seqlen, num_heads, head_size)
    """
    batch, seqlen, num_gqa_groups, head_size = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
    return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size])
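
# A minimal shape sketch for `repeat_kv` (hypothetical sizes, not part of the
# module API):
#
#     kv = paddle.randn([2, 128, 4, 64])    # [batch, seqlen, num_gqa_groups, head_size]
#     k_full = repeat_kv(kv, n_rep=8)       # -> [2, 128, 32, 64]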


class RotaryPositionEmbedding(paddle.nn.Layer):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        max_position_embeddings: int
            max_position_embeddings before position interpolation
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = 1.0 / (10000**(paddle.cast(paddle.arange(0, dim, 2), dtype='float32') /
                                       self.dim))
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype="float32")
        # [seq_len, dim/2]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # [seq_len, dim]
        emb = paddle.concat([freqs, freqs], axis=-1)
        # [1, seqlen, 1, dim]
        self.cos_cached = emb.cos()[None, :, None, :]
        self.sin_cached = emb.sin()[None, :, None, :]

    def forward(self, max_seq_len: int):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample
        """
        # cache layout is [1, seqlen, 1, dim]; slice along the sequence axis
        cos = self.cos_cached[:, :max_seq_len, :, :]
        sin = self.sin_cached[:, :max_seq_len, :, :]
        return (cos, sin)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return paddle.concat([-x2, x1], axis=-1)    # shape is the same as x


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Applies rotary positional embedding to the input."""

    if position_ids is None:
        # Note: Only for LlamaForCausalLMPipe model pretraining
        cos = cos[:, :q.shape[1], :, :]    # [bs, seq_len, 1, dim]
        sin = sin[:, :q.shape[1], :, :]    # [bs, seq_len, 1, dim]
    else:
        cos = cos.squeeze(axis=[0, 2])    # [seq_len, dim]
        sin = sin.squeeze(axis=[0, 2])    # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(2)    # [bs, seq_len, 1, dim]
        sin = sin[position_ids].unsqueeze(2)    # [bs, seq_len, 1, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
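
# A usage sketch for the RoPE helpers above (assumed shapes; `q` and `k` are
# placeholder [batch, seqlen, num_heads, head_size] tensors). Note that
# `rotate_half` maps [x1, x2, x3, x4] to [-x3, -x4, x1, x2] on the last axis.
#
#     rope = RotaryPositionEmbedding(dim=64, max_position_embeddings=2048)
#     cos, sin = rope(max_seq_len=128)    # each [1, 128, 1, 64]
#     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)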


class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, attn_bias, max_seqlen, attn_scale, qkv_dtype, dropout_p,
                set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training,
                fused_attention_backend):
        """Forward function for FusedAttention with packed QKV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
            qkv,
            cu_seqlens,
            is_training,
            max_seqlen,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux)
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed QKV input"""
        qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor()
        dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out, softmax_aux,
                                               ctx.fused_attention_backend, ctx.max_seqlen,
                                               ctx.qkv_dtype, ctx.attn_scale, ctx.dropout_p,
                                               ctx.set_zero, ctx.qkv_layout, ctx.attn_bias_type,
                                               ctx.attn_mask_type)

        # if no_bias, return dqkv
        if ctx.attn_bias_type == "no_bias":
            return (dqkv, None)
        # else, return (dqkv, dbias)
        return (dqkv, None, rest[0])
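
# Note on the return convention above: PyLayer.backward returns one gradient
# per tensor input of forward, so `None` stands in for the integer `cu_seqlens`
# tensor, and the bias gradient is appended only when a bias was supplied.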


class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed KV input"""

    @staticmethod
    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
                attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
                attn_mask_type, is_training, fused_attention_backend):
        """Forward function for FusedAttention with packed KV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
            q, kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, qkv_dtype,
            fused_attention_backend, attn_bias, attn_scale, dropout_p, set_zero, qkv_layout,
            attn_bias_type, attn_mask_type)

        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed KV input"""
        q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dkv, *rest = fused_attn_bwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
                                                 d_out, softmax_aux, ctx.fused_attention_backend,
                                                 ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
                                                 ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
                                                 ctx.qkv_layout, ctx.attn_bias_type,
                                                 ctx.attn_mask_type)

        # if no_bias, return dq, dkv
        if ctx.attn_bias_type == "no_bias":
            return (dq, dkv, None, None)
        # else, return (dq, dkv, dbias)
        return (dq, dkv, None, None, rest[0])


class FusedAttnFunc(paddle.autograd.PyLayer):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
                attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
                attn_mask_type, is_training, fused_attention_backend):
        """Forward function for FusedAttention with separate Q, K, V tensors"""
        out, softmax_aux, rng_state = fused_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_kv,
                                                     is_training, max_seqlen_q, max_seqlen_kv,
                                                     qkv_dtype, fused_attention_backend, attn_bias,
                                                     attn_scale, dropout_p, set_zero, qkv_layout,
                                                     attn_bias_type, attn_mask_type)

        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with separate Q, K, V tensors"""
        q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dk, dv, *rest = fused_attn_bwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
                                           d_out, softmax_aux, ctx.fused_attention_backend,
                                           ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
                                           ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
                                           ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
        # if no_bias, return dq, dk, dv
        if ctx.attn_bias_type == "no_bias":
            return (dq, dk, dv, None, None)
        # else, return (dq, dk, dv, dbias)
        return (dq, dk, dv, None, None, rest[0])
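
# Sketch of the `cu_seqlens` format consumed by the fused kernels (assumed
# values, not part of the API): for per-sample valid lengths [3, 5, 2], the
# cumulative-sequence-length tensor is [0, 3, 8, 10] with dtype int32.
# `mask_to_cu_seqlens` derives this from the attention mask; a hand-rolled
# equivalent from known lengths would be:
#
#     lengths = paddle.to_tensor([3, 5, 2], dtype='int32')
#     cu_seqlens = paddle.concat(
#         [paddle.zeros([1], dtype='int32'), paddle.cumsum(lengths)])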


class DotProductAttention(paddle.nn.Layer):
    """
    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    num_attention_heads: int
            number of attention heads in the transformer layer.
    kv_channels: int
            number of channels in the key and value tensors.
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation.
    """

    def __init__(self,
                 num_attention_heads: int,
                 kv_channels: int,
                 num_gqa_groups: Optional[int] = None,
                 attention_dropout: float = 0.1,
                 attn_mask_type: str = "causal",
                 attention_type: str = "self",
                 tp_size: int = 1,
                 backend: str = 'transformer_engine') -> None:
        super().__init__()

        self.attn_mask_type = attn_mask_type
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        self.qkv_layout = "bshd_bshd_bshd"
        self.hidden_size_per_attention_head = kv_channels
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.tp_size = tp_size
        self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups

        self.backend = backend

        self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))

        if not self.use_fused_attention and backend == 'transformer_engine':
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            self.backend = 'paddle'

        if self.backend != 'transformer_engine':
            self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type,
                                                            attention_mask_func,
                                                            backend=self.backend)

    def forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
            is set to `"causal"`.


        Parameters
        ----------
        query_layer : paddle.Tensor
                      Query tensor.
        key_layer : paddle.Tensor
                      Key tensor.
        value_layer : paddle.Tensor
                      Value tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                         Boolean tensor used to mask out the softmax input.
        core_attention_bias_type: str, default = `no_bias`
                                  Bias type. Only `no_bias` is currently supported.
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                             Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                  Whether to use the fast path to set output tensors to 0 or not.
        """

        backend = self.backend

        assert (key_layer.shape == value_layer.shape), "Keys and values must have the same shape!"
        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
               ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"

        if backend == 'transformer_engine':
            max_s_q = query_layer.shape[1]
            max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1]
            self.fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
                tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
                AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2],
                key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], max_s_q,
                max_s_kv, query_layer.shape[-1])

            is_backend_avail = (self.fused_attention_backend in [
                FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
            ])
            if is_backend_avail and self.use_fused_attention:
                return self._te_forward(query_layer, key_layer, value_layer, attention_mask,
                                        core_attention_bias_type, core_attention_bias, set_zero)
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            backend = 'paddle'
            self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type,
                                                            attention_mask_func,
                                                            backend=backend)
        if backend == 'paddle':
            if core_attention_bias_type != "no_bias":
                warnings.warn("Paddle backend dot product attention does not support bias yet. "
                              "Bias will be ignored.")
            return self._pd_forward(query_layer, key_layer, value_layer, attention_mask)
        raise AttributeError(f"Backend {backend} is not supported.")

    def _te_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:

        if self.attention_type == "self":
            # self attention - q, k, v: [b, s, h, d]
            assert (len(query_layer.shape) == 4 and len(key_layer.shape) == 4
                    and len(value_layer.shape)
                    == 4), "q,k,v shape must be [b, s, h, d] for dot product self attention"
            max_seqlen = query_layer.shape[1]
            if self.attn_mask_type == "causal" or attention_mask is None:
                cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1],
                                           step=query_layer.shape[1],
                                           dtype='int32')
            else:
                cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
            qkv_dtype = TE_DType[query_layer.dtype]

            output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens,
                                         cu_seqlens, core_attention_bias, max_seqlen, max_seqlen,
                                         1.0 / self.norm_factor, qkv_dtype,
                                         self.attention_dropout if self.training else 0.0, set_zero,
                                         self.qkv_layout, core_attention_bias_type,
                                         self.attn_mask_type, self.training,
                                         self.fused_attention_backend)
        elif self.attention_type == "cross":
            # cross attention - q: [b, s_q, h, d]  k,v: [b, s_kv, h, d]
            assert (
                len(query_layer.shape) == 4 and len(key_layer.shape) == 4
                and len(value_layer.shape) == 4
            ), "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" \
                "for dot product cross attention"
            assert (attention_mask
                    is not None), "attention_mask must be provided for cross attention"
            max_seqlen_q = query_layer.shape[1]
            max_seqlen_kv = key_layer.shape[1]
            cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
            qkv_dtype = TE_DType[query_layer.dtype]
            output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens_q,
                                         cu_seqlens_kv, core_attention_bias, max_seqlen_q,
                                         max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
                                         self.attention_dropout if self.training else 0.0, set_zero,
                                         self.qkv_layout, core_attention_bias_type,
                                         self.attn_mask_type, self.training,
                                         self.fused_attention_backend)
        else:
            raise ValueError("attention_type must be one of ['self', 'cross']")
        return output

    def _pd_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
    ) -> paddle.Tensor:

        q = query_layer
        k = repeat_kv(key_layer, self.num_queries_per_key_value)
        v = repeat_kv(value_layer, self.num_queries_per_key_value)

        q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
        k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
        v = paddle.transpose(x=v, perm=[0, 2, 1, 3])

        product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True)
        attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None)

        if self.attention_dropout > 0:
            attention_probs = F.dropout(
                attention_probs,
                self.attention_dropout,
                training=self.training,
            )

        out = paddle.matmul(attention_probs, v)
        out = paddle.transpose(out, perm=[0, 2, 1, 3])    # [b, s, h, d]
        # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
        return out
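
# A minimal usage sketch for DotProductAttention (hypothetical sizes; causal
# masking, so no attention_mask is passed; fused backends require fp16/bf16,
# while fp32 inputs fall back to the paddle backend):
#
#     attn = DotProductAttention(num_attention_heads=16, kv_channels=64,
#                                attn_mask_type="causal")
#     q = paddle.randn([2, 128, 16, 64]).astype('bfloat16')
#     k = paddle.randn([2, 128, 16, 64]).astype('bfloat16')
#     v = paddle.randn([2, 128, 16, 64]).astype('bfloat16')
#     out = attn(q, k, v)    # -> [2, 128, 16, 64]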


class MultiHeadAttention(paddle.nn.Layer):
    """
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    hidden_size: int
                    hidden size of the model.
    num_attention_heads: int
                    number of attention heads.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon: float, default = 1e-5
                          epsilon to use in the layer norm operations.
    weight_attr: Union[paddle.ParamAttr, None], default = `None`
                    paddle.ParamAttr object for the weight parameter.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = `None`
                    paddle.ParamAttr object for the bias parameter.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    params_dtype: Optional[paddle.dtype], default = `None`
                    data type for the weights and biases.
    return_layernorm_output: bool, default = `False`
                    whether to return the output of the layernorm operation.
    input_layernorm: bool, default = `False`
                    whether to apply layernorm to the input.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    zero_centered_gamma: bool, default = `False`
                    whether to zero initialize the gamma of the layernorm operation.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation. If set to 'paddle', a
             framework-only, non-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    num_gqa_groups : int, default = `None`
                     number of GQA groups in the transformer layer.
                     Grouped Query Attention is described in
                     `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                     This only affects the keys and values, not the queries.
                     GQA-1 is equivalent to Multi-Query Attention
                     (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                     is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    rng_state_name : str, default = `local_seed`
                   Controls the rng state used for dropout on attention probs. The
                   specified rng state should be seeded differently on different TP ranks.
                   It will be ignored if `set_parallel_mode` is False. The specified
                   name should be registered through
                   `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
                   .add(rng_state_name, seed)`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        max_sequence_length: Optional[int] = None,
        attn_mask_type: str = "causal",
        params_dtype: Optional[paddle.dtype] = None,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
        set_parallel_mode: bool = False,
        sequence_parallel: bool = False,
        tp_group: Optional[dist_group_type] = None,
        num_gqa_groups: Optional[int] = None,
        rng_state_name: str = 'local_seed',
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
        self.max_sequence_length = max_sequence_length
        self.weight_attr = weight_attr
        self.bias_attr = bias_attr
        self.attn_mask_type = attn_mask_type

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                  enable_tp=set_parallel_mode)
        self.tensor_parallel = self.tp_size > 1
        self.sequence_parallel = self.tensor_parallel and sequence_parallel
        self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads
        self.set_parallel_mode = set_parallel_mode
        self.rng_state_name = rng_state_name
        self.backend = backend

        self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
        self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
        assert (self.num_attention_heads % self.num_gqa_groups == 0
               ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (self.num_gqa_groups % self.tp_size == 0
               ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads)
        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    normalization=normalization,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )

        else:    # cross attention
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    normalization=normalization,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    backend=self.backend,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                self.weight_attr,
                self.bias_attr,
                parallel_mode=qkv_parallel_mode,
                sequence_parallel=self.sequence_parallel,
                tp_group=self.tp_group,
                backend=self.backend,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.num_gqa_groups,
            attention_dropout,
            attn_mask_type=attn_mask_type,
            attention_type=self.attention_type,
            tp_size=self.tp_size,
            backend=self.backend,
        )

        # Linear
        self.proj = Linear(
            hidden_size,
            hidden_size,
            self.weight_attr,
            self.bias_attr,
            parallel_mode="row" if set_parallel_mode else None,
            sequence_parallel=self.sequence_parallel,
            tp_group=self.tp_group,
            backend=self.backend,
        )

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_output: Optional[paddle.Tensor] = None,
        rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
        recompute_core_attention: bool = False,
        is_first_microbatch: Optional[bool] = None,
    ) -> Tuple[Union[paddle.Tensor, None], ...]:
        """
        MultiHeadAttention Layer.

        Parameters
        ----------
        hidden_states : paddle.Tensor
                        Input tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                        Boolean tensor used to mask out the softmax input.
        encoder_output : Optional[paddle.Tensor], default = `None`
                        Output of the encoder layer.
        rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                                Bias type. Only `no_bias` is currently supported.
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                    Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
        recompute_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
        """

        if self.attn_mask_type != "causal" and attention_mask is not None:
            assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"

        input_dim = len(hidden_states.shape)
        if input_dim == 2:
            # hidden_states: [b * s_q, hidden_size]
            # need max_sequence_length to recover the sequence dimension
            assert self.max_sequence_length is not None, "max_sequence_length must be provided"
            max_seq_len = self.max_sequence_length
        elif input_dim == 3:
            # hidden_states: [b, s_q, hidden_size]
            max_seq_len = hidden_states.shape[1]
        else:
            raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")

        if self.attention_type == "self":
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_qkv_layer = layernorm_qkv_outputs
            else:
                mixed_qkv_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            num_queries_per_key_value = (self.num_attention_heads_per_partition //
                                         self.num_gqa_groups_per_partition)

            # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d]
            mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
                -1, max_seq_len, (
                    num_queries_per_key_value +
                    2), self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
            ])

            # [b, s_q, (h/ng+2), ng, d]
            # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d]
            query_layer, key_layer, value_layer = paddle.split(
                mixed_qkv_layer,
                num_or_sections=(num_queries_per_key_value, 1, 1),
                axis=2,
            )

            # query: -> [b, s, h, d]
            # key, value: -> [b, s, ng, d]
            query_layer, key_layer, value_layer = (x.reshape(
                shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
                                                   for x in (query_layer, key_layer, value_layer))

        else:    # cross attention
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
            )
            # [b, s_kv, 2 * hidden_size_kv] --> [b, s_kv, 2 * ng, head_size]
            mixed_kv_layer = mixed_kv_layer.reshape(shape=[
                0, 0, 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
            ])

            # [b, s_kv, 2 * ng, head_size]
            # --> 2 [b, s_kv, ng, head_size]
            key_layer, value_layer = paddle.split(
                mixed_kv_layer,
                num_or_sections=2,
                axis=2,
            )

            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [b, s, hidden_size] --> [b, s, h, d]
            query_layer = query_layer.reshape(shape=[
                -1, max_seq_len, self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head
            ])

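        # rotary_pos_emb is the (cos, sin) pair produced by
        # RotaryPositionEmbedding.forward, so q_pos_emb below holds cos and
        # k_pos_emb holds sin.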
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            if fused_rotary_position_embedding is None:
                query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, q_pos_emb,
                                                              k_pos_emb)
            else:
                query_layer, key_layer, _ = fused_rotary_position_embedding(
                    query_layer,
                    key_layer,
                    v=None,
                    sin=k_pos_emb,
                    cos=q_pos_emb,
                    position_ids=None,
                    use_neox_rotary_style=False,
                )

        with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
            if recompute_core_attention:
                context_layer = recompute(
                    self.core_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask,
                    core_attention_bias_type,
                    core_attention_bias,
                    set_zero,
                    use_reentrant=False,
                )
            else:
                context_layer = self.core_attention(
                    query_layer=query_layer,
                    key_layer=key_layer,
                    value_layer=value_layer,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    set_zero=set_zero,
                )

        if input_dim == 3:
            context_layer = paddle.reshape(
                context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]])
        else:    # input_dim == 2
            context_layer = paddle.reshape(context_layer,
                                           [-1, context_layer.shape[2] * context_layer.shape[3]])

        # Output. [b, s, hidden]
        attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch)

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, layernorm_output
        return attention_output
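

# A minimal usage sketch for MultiHeadAttention (hypothetical sizes; single
# device, self attention with the default causal mask):
#
#     mha = MultiHeadAttention(hidden_size=1024, num_attention_heads=16)
#     x = paddle.randn([2, 128, 1024])    # [batch, seqlen, hidden_size]
#     y = mha(x)                          # -> [2, 128, 1024]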