attention.py 41.3 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Shijie's avatar
Shijie committed
2
3
4
5
6
#
# See LICENSE for license information.
"""Attntion API"""

import math
7
import os
Shijie's avatar
Shijie committed
8
9
10
11
12
import warnings
from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F
13
14
15
16
try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None
17
import transformer_engine_paddle as tex
Shijie's avatar
Shijie committed
18

19
20
21
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
22
23
from ..constants import (AttnTypes, TE_DType, AttnBiasType, AttnMaskType, FusedAttnBackend,
                         dist_group_type)
24
from ..cpp_extensions import (
Shijie's avatar
Shijie committed
25
26
27
28
    fused_attn_fwd_qkvpacked,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
Shijie's avatar
Shijie committed
29
30
    fused_attn_fwd,
    fused_attn_bwd,
31
    mask_to_cu_seqlens,
Shijie's avatar
Shijie committed
32
)
33
from ..distributed import get_tp_group_and_world_size, track_rng_state
34
from ..utils import attention_mask_func, divide
Tian Zheng's avatar
Tian Zheng committed
35
from ..recompute import recompute
Shijie's avatar
Shijie committed
36

37
__all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"]
38
39


Shijie's avatar
Shijie committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def repeat_kv(hidden_states: "paddle.Tensor", n_rep: int) -> "paddle.Tensor":
    """
    Expand key/value heads for grouped-query attention (GQA).

    Each of the `num_gqa_groups` KV heads is repeated `n_rep` times in
    place, taking the tensor from (batch, seqlen, num_gqa_groups, head_size)
    to (batch, seqlen, num_gqa_groups * n_rep, head_size).
    """
    # Nothing to do for MHA (one query per KV head).
    if n_rep == 1:
        return hidden_states

    batch, seqlen, groups, head_size = hidden_states.shape
    # Insert a repeat axis next to the head axis, tile it, then fold the
    # two axes back together so repeats of a head stay adjacent.
    expanded = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
    return expanded.reshape([batch, seqlen, groups * n_rep, head_size])


54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class RotaryPositionEmbedding(paddle.nn.Layer):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int,
    ):
        """
        Parameters
        ----------
        dim: int
            rotary embedding dimension
        max_position_embeddings: int
            max_position_embeddings before position interpolation
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # Standard RoPE inverse frequencies: 1 / 10000^(2i/dim) for i in [0, dim/2).
        self.inv_freq = 1.0 / (10000**(paddle.cast(paddle.arange(0, dim, 2), dtype='float32') /
                                       self.dim))
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        """Precompute cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        # [seq_len]
        t = paddle.arange(seq_len, dtype="float32")
        # [seq_len, dim/2]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # [seq_len, dim] -- frequencies duplicated so rotate_half pairs line up
        emb = paddle.concat([freqs, freqs], axis=-1)
        # [1, seq_len, 1, dim]
        self.cos_cached = emb.cos()[None, :, None, :]
        self.sin_cached = emb.sin()[None, :, None, :]

    def forward(self, max_seq_len: int):
        """
        Create rotary position embedding frequencies

        Parameters
        ----------
        max_seq_len: int
            sequence length of a sample

        Returns
        -------
        Tuple of (cos, sin) tensors with shape [1, max_seq_len, 1, dim].
        """
        # Robustness fix: grow the cache when a longer sequence than cached
        # is requested instead of silently returning a too-short table.
        if max_seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=max_seq_len)
        # Bug fix: the cache layout is [1, seq, 1, dim], so truncation must
        # slice axis 1. The original sliced axis 2 (size 1), a no-op that
        # always returned the full cached length.
        cos = self.cos_cached[:, :max_seq_len, :, :]
        sin = self.sin_cached[:, :max_seq_len, :, :]
        return (cos, sin)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    # Result has the same shape as x: (x1, x2) -> (-x2, x1).
    return paddle.concat([-second, first], axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Applies rotary positional embedding to the input."""
    if position_ids is None:
        # Note: Only for LlamaForCausalLMPipe model pretraining
        seq_len = q.shape[1]
        cos = cos[:, :seq_len, :, :]    # [bs, seq_len, 1, dim]
        sin = sin[:, :seq_len, :, :]    # [bs, seq_len, 1, dim]
    else:
        # Drop the broadcast axes, gather per-token frequencies, then
        # restore the singleton head axis for broadcasting over heads.
        cos = cos.squeeze(axis=[0, 2])[position_ids].unsqueeze(2)    # [bs, seq_len, 1, dim]
        sin = sin.squeeze(axis=[0, 2])[position_ids].unsqueeze(2)    # [bs, seq_len, 1, dim]
    rotated_q = (q * cos) + (rotate_half(q) * sin)
    rotated_k = (k * cos) + (rotate_half(k) * sin)
    return rotated_q, rotated_k


Shijie's avatar
Shijie committed
129
130
131
132
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed QKV input"""

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, attn_bias, max_seqlen, attn_scale, qkv_dtype, dropout_p,
                set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training,
                fused_attention_backend):
        """Forward function for FusedAttention with packed QKV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
            qkv,
            cu_seqlens,
            is_training,
            max_seqlen,
            qkv_dtype,
            fused_attention_backend,
            attn_bias,
            attn_scale,
            dropout_p,
            set_zero,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
        )

        # Stash tensors and hyper-parameters required by the backward pass.
        ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux)
        ctx.max_seqlen = max_seqlen
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed QKV input"""
        qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor()
        dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out, softmax_aux,
                                               ctx.fused_attention_backend, ctx.max_seqlen,
                                               ctx.qkv_dtype, ctx.attn_scale, ctx.dropout_p,
                                               ctx.set_zero, ctx.qkv_layout, ctx.attn_bias_type,
                                               ctx.attn_mask_type)
        # With no bias only dqkv flows back; otherwise dbias is appended.
        if ctx.attn_bias_type == "no_bias":
            return (dqkv, None)
        return (dqkv, None, rest[0])
Shijie's avatar
Shijie committed
181
182
183
184
185
186


class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed KV input"""

    @staticmethod
    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
                attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
                attn_mask_type, is_training, fused_attention_backend):
        """Forward function for FusedAttention with packed KV input"""
        out, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
            q, kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, qkv_dtype,
            fused_attention_backend, attn_bias, attn_scale, dropout_p, set_zero, qkv_layout,
            attn_bias_type, attn_mask_type)

        # Stash tensors and hyper-parameters required by the backward pass.
        ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, d_out):
        """Backward function for FusedAttention with packed KV input"""
        q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dkv, *rest = fused_attn_bwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
                                                 d_out, softmax_aux, ctx.fused_attention_backend,
                                                 ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
                                                 ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
                                                 ctx.qkv_layout, ctx.attn_bias_type,
                                                 ctx.attn_mask_type)
        # With no bias only (dq, dkv) flow back; otherwise dbias is appended.
        if ctx.attn_bias_type == "no_bias":
            return (dq, dkv, None, None)
        return (dq, dkv, None, None, rest[0])
Shijie's avatar
Shijie committed
226
227


Shijie's avatar
Shijie committed
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
class FusedAttnFunc(paddle.autograd.PyLayer):
    """Function for FusedAttention with separate Q, K, V tensors"""

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
                attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
                attn_mask_type, is_training, fused_attention_backend):
        """Forward function for FusedAttention with separate Q, K, V tensors"""
        out, softmax_aux, rng_state = fused_attn_fwd(
            q, k, v, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv,
            qkv_dtype, fused_attention_backend, attn_bias, attn_scale, dropout_p, set_zero,
            qkv_layout, attn_bias_type, attn_mask_type)

        # Stash tensors and hyper-parameters required by the backward pass.
        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.qkv_dtype = qkv_dtype
        ctx.attn_scale = attn_scale
        ctx.dropout_p = dropout_p
        ctx.set_zero = set_zero
        ctx.qkv_layout = qkv_layout
        ctx.attn_bias_type = attn_bias_type
        ctx.attn_mask_type = attn_mask_type
        ctx.fused_attention_backend = fused_attention_backend

        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Backward function for FusedAttention with separate Q, K, V tensors"""
        q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
        dq, dk, dv, *rest = fused_attn_bwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
                                           grad_out, softmax_aux, ctx.fused_attention_backend,
                                           ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
                                           ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
                                           ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
        # With no bias only (dq, dk, dv) flow back; otherwise dbias is appended.
        if ctx.attn_bias_type == "no_bias":
            return (dq, dk, dv, None, None)
        return (dq, dk, dv, None, None, rest[0])


Shijie's avatar
Shijie committed
272
class DotProductAttention(paddle.nn.Layer):
    """
    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    num_attention_heads: int
            number of attention heads in the transformer layer.
    kv_channels: int
            number of channels in the key and value tensors.
    num_gqa_groups : Optional[int] = None
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    tp_size: int, default = 1
             tensor parallel world size.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation.
    """

    def __init__(self,
                 num_attention_heads: int,
                 kv_channels: int,
                 num_gqa_groups: Optional[int] = None,
                 attention_dropout: float = 0.1,
                 attn_mask_type: str = "causal",
                 attention_type: str = "self",
                 tp_size: int = 1,
                 backend: str = 'transformer_engine') -> None:
        super().__init__()

        self.attn_mask_type = attn_mask_type
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
        # Fused kernels consume separate q/k/v tensors in [b, s, h, d] layout.
        self.qkv_layout = "bshd_bshd_bshd"
        self.hidden_size_per_attention_head = kv_channels
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.tp_size = tp_size
        self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
        self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups

        self.backend = backend

        # NVTE_FUSED_ATTN=0 disables fused attention globally.
        self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1")))

        if not self.use_fused_attention and backend == 'transformer_engine':
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            self.backend = 'paddle'

        if self.backend != 'transformer_engine':
            self.scale_mask_softmax = FusedScaleMaskSoftmax(attn_mask_type,
                                                            attention_mask_func,
                                                            backend=self.backend)

    def forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
            is set to `"causal"`.

        Parameters
        ----------
        query_layer : paddle.Tensor
                      Query tensor.
        key_layer : paddle.Tensor
                      Key tensor.
        value_layer : paddle.Tensor
                      Value tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                         Boolean tensor used to mask out softmax input when not using attention.
        core_attention_bias_type: str, default = `no_bias`
                                  only support no_bias type currently, {`no_bias`}
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                             Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                  Whether to use the fast path to set output tensors to 0 or not.
        """
        backend = self.backend

        assert (key_layer.shape == value_layer.shape), "Keys and values must have the same shape!"
        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
               ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"

        if backend == 'transformer_engine':
            max_s_q = query_layer.shape[1]
            max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1]
            # Ask the extension which fused kernel (if any) supports this config.
            self.fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
                tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
                AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2],
                key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], max_s_q,
                max_s_kv, query_layer.shape[-1])

            is_backend_avail = (self.fused_attention_backend in [
                FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
            ])
            if is_backend_avail and self.use_fused_attention:
                return self._te_forward(query_layer, key_layer, value_layer, attention_mask,
                                        core_attention_bias_type, core_attention_bias, set_zero)
            # No usable fused kernel: fall back to the plain Paddle path.
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            backend = 'paddle'
            self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type,
                                                            attention_mask_func,
                                                            backend=backend)
        if backend == 'paddle':
            if core_attention_bias_type != "no_bias":
                warnings.warn("Paddle backend dot product attention does not support bias yet. "
                              "Bias will be ignored.")
            return self._pd_forward(query_layer, key_layer, value_layer, attention_mask)
        raise AttributeError(f"Backend {backend} is not supported.")

    def _te_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
    ) -> paddle.Tensor:
        """Dispatch to the fused attention kernel for self or cross attention."""
        if self.attention_type == "self":
            # self attention - q: [b, s, h, d]  kv: None
            assert (len(query_layer.shape) == 4 and len(key_layer.shape) == 4
                    and len(value_layer.shape)
                    == 4), "q,k,v shape must be [b, s, h, d] for dot product self attention"
            max_seqlen = query_layer.shape[1]
            if self.attn_mask_type == "causal" or attention_mask is None:
                # Full-length sequences: cumulative lengths are a simple ramp.
                cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1],
                                           step=query_layer.shape[1],
                                           dtype='int32')
            else:
                cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
            qkv_dtype = TE_DType[query_layer.dtype]

            output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens,
                                         cu_seqlens, core_attention_bias, max_seqlen, max_seqlen,
                                         1.0 / self.norm_factor, qkv_dtype,
                                         self.attention_dropout if self.training else 0.0, set_zero,
                                         self.qkv_layout, core_attention_bias_type,
                                         self.attn_mask_type, self.training,
                                         self.fused_attention_backend)
        elif self.attention_type == "cross":
            # cross attention - q: [b, s_q, h, d]  k,v: [b, s_kv, h, d]
            assert (
                len(query_layer.shape) == 4 and len(key_layer.shape) == 4
                and len(value_layer.shape) == 4
            ), "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" \
                "for dot product cross attention"
            assert (attention_mask
                    is not None), "attention_mask must be provided for cross attention"
            max_seqlen_q = query_layer.shape[1]
            max_seqlen_kv = key_layer.shape[1]
            cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
            qkv_dtype = TE_DType[query_layer.dtype]
            output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens_q,
                                         cu_seqlens_kv, core_attention_bias, max_seqlen_q,
                                         max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
                                         self.attention_dropout if self.training else 0.0, set_zero,
                                         self.qkv_layout, core_attention_bias_type,
                                         self.attn_mask_type, self.training,
                                         self.fused_attention_backend)
        else:
            raise ValueError("attention_type must be one of ['self', 'cross']")
        return output

    def _pd_forward(
        self,
        query_layer: paddle.Tensor,
        key_layer: paddle.Tensor,
        value_layer: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
    ) -> paddle.Tensor:
        """Reference attention implemented with plain Paddle ops."""
        q = query_layer
        # Expand GQA key/value heads so they match the query head count.
        k = repeat_kv(key_layer, self.num_queries_per_key_value)
        v = repeat_kv(value_layer, self.num_queries_per_key_value)

        # [b, s, h, d] -> [b, h, s, d] for the batched matmuls.
        q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
        k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
        v = paddle.transpose(x=v, perm=[0, 2, 1, 3])

        # Scaled dot-product scores, then masked softmax.
        product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True)
        attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None)

        if self.attention_dropout > 0:
            attention_probs = F.dropout(
                attention_probs,
                self.attention_dropout,
                training=self.training,
            )

        out = paddle.matmul(attention_probs, v)
        out = paddle.transpose(out, perm=[0, 2, 1, 3])    # [b, s, h, d]
        return out


503
class MultiHeadAttention(paddle.nn.Layer):
    """
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    hidden_size: int
                    hidden size of the model.
    num_attention_heads: int
                    number of attention heads.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    layernorm_epsilon: float, default = 1e-5
                          epsilon to use in the layer norm operations.
    weight_attr: Union[paddle.ParamAttr, None], default = `None`
                    paddle.ParamAttr object for the weight parameter.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = `None`
                    paddle.ParamAttr object for the bias parameter.
    max_sequence_length: Optional[int], default = `None`
                    maximum query sequence length; required (and used) only when
                    `forward` receives 2-D `[b * s_q, hidden_size]` input, where
                    the sequence length cannot be inferred from the tensor shape.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    params_dtype: Optional[paddle.dtype], default = `None`
                    data type for the weights and biases.
    return_layernorm_output: bool, default = `False`
                    whether to return the output of the layernorm operation.
    input_layernorm: bool, default = `False`
                    whether to apply layernorm to the input.
    attention_type: {'self', 'cross'}, default = `self`
                    type of attention operation.
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    zero_centered_gamma: bool, default = `False`
                    whether to zero initialize the gamma of the layernorm operation.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
             backend to use for attention operation. If set to 'paddle', a framework
             only no-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    num_gqa_groups : int, default = `None`
                     number of GQA groups in the transformer layer.
                     Grouped Query Attention is described in
                     `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                     This only affects the keys and values, not the queries.
                     GQA-1 is equivalent to Multi-Query Attention
                     (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                     is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    rng_state_name : str, default = `local_seed`
                   Controls the rng state used for dropout on attention probs. The
                   specified rng should be set different seeds for different TP ranks.
                   It will be ignored if `set_parallel_mode` is False. The specified
                   name should be registered through
                   `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
                   .add(rng_state_name, seed)`.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout: float = 0.1,
        layernorm_epsilon: float = 1e-5,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        max_sequence_length: Optional[int] = None,
        attn_mask_type: str = "causal",
        params_dtype: Optional[paddle.dtype] = None,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        normalization: str = "LayerNorm",
        zero_centered_gamma: bool = False,
        set_parallel_mode: bool = False,
        sequence_parallel: bool = False,
        tp_group: Optional[dist_group_type] = None,
        num_gqa_groups: Optional[int] = None,
        fuse_wgrad_accumulation: bool = False,
        rng_state_name: str = 'local_seed',
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
        self.max_sequence_length = max_sequence_length
        self.weight_attr = weight_attr
        self.bias_attr = bias_attr
        self.attn_mask_type = attn_mask_type

        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"

        # Resolve the tensor-parallel group/size; TP is active only when
        # set_parallel_mode is requested AND the resulting world size > 1.
        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                  enable_tp=set_parallel_mode)
        self.tensor_parallel = self.tp_size > 1
        # Sequence parallelism is only meaningful on top of tensor parallelism.
        self.sequence_parallel = self.tensor_parallel and sequence_parallel
        self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads
        self.set_parallel_mode = set_parallel_mode
        self.rng_state_name = rng_state_name
        self.backend = backend

        # Per-TP-rank head counts: both query heads and GQA groups are sharded
        # evenly across the tensor-parallel ranks.
        self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
        self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
        assert (self.num_attention_heads % self.num_gqa_groups == 0
               ), "The number of attention heads must be divisible by the number of GQA groups!"
        assert (self.num_gqa_groups % self.tp_size == 0
               ), "The number of GQA groups must be divisible by tensor parallel size!"
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
        # Combined K/V projection width: one head_size slice per GQA group.
        self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads)
        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            # Self attention: a single fused projection produces Q, K and V
            # ([hidden_size] -> [hidden_size + 2 * hidden_size_kv]), optionally
            # with the input layernorm fused in.
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    normalization=normalization,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    hidden_size + 2 * self.hidden_size_kv,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )

        else:    # cross attention
            # Cross attention: Q is projected from hidden_states (optionally
            # with fused layernorm), while K/V come from the encoder output
            # through a separate fused K/V projection.
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    weight_attr=self.weight_attr,
                    bias_attr=self.bias_attr,
                    return_layernorm_output=return_layernorm_output,
                    normalization=normalization,
                    zero_centered_gamma=zero_centered_gamma,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    self.weight_attr,
                    self.bias_attr,
                    parallel_mode=qkv_parallel_mode,
                    sequence_parallel=self.sequence_parallel,
                    tp_group=self.tp_group,
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                    backend=self.backend,
                )
            self.key_value = Linear(
                hidden_size,
                2 * self.hidden_size_kv,
                self.weight_attr,
                self.bias_attr,
                parallel_mode=qkv_parallel_mode,
                sequence_parallel=self.sequence_parallel,
                tp_group=self.tp_group,
                fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                backend=self.backend,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.num_gqa_groups,
            attention_dropout,
            attn_mask_type=attn_mask_type,
            attention_type=self.attention_type,
            tp_size=self.tp_size,
            backend=self.backend,
        )

        # Output projection (row-parallel when tensor parallelism is enabled).
        self.proj = Linear(
            hidden_size,
            hidden_size,
            self.weight_attr,
            self.bias_attr,
            parallel_mode="row" if set_parallel_mode else None,
            sequence_parallel=self.sequence_parallel,
            tp_group=self.tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            backend=self.backend,
        )

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_output: Optional[paddle.Tensor] = None,
        rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
        recompute_core_attention: bool = False,
        is_first_microbatch: Optional[bool] = None,
    ) -> Tuple[Union[paddle.Tensor, None], ...]:
        """
        MultiHeadAttention Layer.

        Parameters
        ----------
        hidden_states : paddle.Tensor
                        Input tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                        Boolean tensor used to mask out softmax input when not using attention.
        encoder_output : Optional[paddle.Tensor], default = `None`
                        Output of the encoder layer.
        rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                                only support no_bias type currently, {`no_bias`}
        core_attention_bias: Optional[paddle.Tensor], default = `None`
                    Bias tensor for Q * K.T
        set_zero: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
        recompute_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
        """

        # A user-supplied mask only makes sense for non-causal mask types and
        # must be boolean (True entries are masked out inside the softmax).
        if self.attn_mask_type != "causal" and attention_mask is not None:
            assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"

        input_dim = len(hidden_states.shape)
        if input_dim == 2:
            # hidden_states: [b * s_q, hidden_size] — the sequence length is
            # not recoverable from the shape, so it must come from config.
            assert self.max_sequence_length is not None, "max_sequence_length must be provided"
            max_seq_len = self.max_sequence_length
        elif input_dim == 3:
            # hidden_states: [b, s_q, hidden_size]
            max_seq_len = hidden_states.shape[1]
        else:
            raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.")

        if self.attention_type == "self":
            # Fused QKV projection (with optional fused input layernorm).
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_qkv_layer = layernorm_qkv_outputs
            else:
                mixed_qkv_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # How many query heads share each GQA key/value group on this rank.
            num_queries_per_key_value = (self.num_attention_heads_per_partition //
                                         self.num_gqa_groups_per_partition)

            # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d]
            mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
                -1, max_seq_len, (
                    num_queries_per_key_value +
                    2), self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
            ])

            # [b, s_q, (h/ng+2), ng, d]
            # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d]
            query_layer, key_layer, value_layer = paddle.split(
                mixed_qkv_layer,
                num_or_sections=(num_queries_per_key_value, 1, 1),
                axis=2,
            )

            # query: -> [b, s, h, d]
            # key, value: -> [b, s, ng, d]
            query_layer, key_layer, value_layer = (x.reshape(
                shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
                                                   for x in (query_layer, key_layer, value_layer))

        else:    # cross attention
            # K/V are projected from the encoder output.
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
            )
            # [b, s_kv, 2 * hidden_size_kv] --> [b, s_kv, 2 * ng, head_size]
            # (leading 0s in the shape keep the original batch/seq dims).
            mixed_kv_layer = mixed_kv_layer.reshape(shape=[
                0, 0, 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
            ])

            # [b, s_kv, 2 * ng, head_size]
            # --> 2 [b, s_kv, ng, head_size]
            key_layer, value_layer = paddle.split(
                mixed_kv_layer,
                num_or_sections=2,
                axis=2,
            )

            # Q is projected from hidden_states (with optional fused layernorm).
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [b, s, hidden_size] --> [b, s, h, d]
            query_layer = query_layer.reshape(shape=[
                -1, max_seq_len, self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head
            ])

        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            if fused_rotary_position_embedding is None:
                # Pure-Paddle fallback when the fused op is unavailable.
                query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, q_pos_emb,
                                                              k_pos_emb)
            else:
                # NOTE(review): the fused op receives k_pos_emb as `sin` and
                # q_pos_emb as `cos` — presumably the two halves of the rotary
                # table rather than per-Q/per-K embeddings; confirm against
                # `apply_rotary_pos_emb` and the fused-op contract.
                query_layer, key_layer, _ = fused_rotary_position_embedding(
                    query_layer,
                    key_layer,
                    v=None,
                    sin=k_pos_emb,
                    cos=q_pos_emb,
                    position_ids=None,
                    use_neox_rotary_style=False,
                )

        # Scope the dropout RNG so different TP ranks draw different masks.
        with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
            if recompute_core_attention:
                # Recompute core attention in backward to save activation memory.
                context_layer = recompute(
                    self.core_attention,
                    query_layer,
                    key_layer,
                    value_layer,
                    attention_mask,
                    core_attention_bias_type,
                    core_attention_bias,
                    set_zero,
                    use_reentrant=False,
                )
            else:
                context_layer = self.core_attention(
                    query_layer=query_layer,
                    key_layer=key_layer,
                    value_layer=value_layer,
                    attention_mask=attention_mask,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    set_zero=set_zero,
                )

        # Collapse the head/dim axes back to hidden_size, matching the
        # dimensionality (2-D or 3-D) of the original input.
        if input_dim == 3:
            context_layer = paddle.reshape(
                context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]])
        else:    # input_dim == 2
            context_layer = paddle.reshape(context_layer,
                                           [-1, context_layer.shape[2] * context_layer.shape[3]])

        # Output. [b, s, hidden]
        attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch)

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, layernorm_output
        return attention_output