# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)


class MultiheadAttention(nn.Layer):
    """Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: "Attention Is All You Need".

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O

        \text{where } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: number of parallel attention heads
        dropout: dropout probability applied to the attention weights
        bias: whether the output projection layer uses a bias
        add_bias_kv: accepted for interface compatibility, unused here
        add_zero_attn: accepted for interface compatibility, unused here
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
        self._reset_parameters()
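        # Three 1x1 convolutions serve as the input projections that produce
        # Q, K and V (see _in_proj_q / _in_proj_k / _in_proj_v below).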
        self.conv1 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv2 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv3 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))

    def _reset_parameters(self):
        xavier_uniform_(self.out_proj.weight)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: if True, mask padding based on batch size
            incremental_state: if provided, previous time steps are cashed
            need_weights: output attn_output_weights
            static_kv: key and value are static

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """
        q_shape = paddle.shape(query)
        src_shape = paddle.shape(key)
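        # Project the inputs to Q, K, V and pre-scale the queries by
        # 1/sqrt(head_dim).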
        q = self._in_proj_q(query)
        k = self._in_proj_k(key)
        v = self._in_proj_v(value)
        q *= self.scaling
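        # Split embed_dim across the heads:
        # [seq_len, batch, embed_dim] -> [batch, num_heads, seq_len, head_dim].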
        q = paddle.transpose(
            paddle.reshape(
                q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        k = paddle.transpose(
            paddle.reshape(
                k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        v = paddle.transpose(
            paddle.reshape(
                v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        if key_padding_mask is not None:
            assert key_padding_mask.shape[0] == q_shape[1]
            assert key_padding_mask.shape[1] == src_shape[0]
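        # Scaled dot-product scores of shape [batch, num_heads, tgt_len, src_len].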
        attn_output_weights = paddle.matmul(q,
                                            paddle.transpose(k, [0, 1, 3, 2]))
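        # attn_mask is an additive mask over [tgt_len, src_len], broadcast
        # across batch and heads via the two inserted leading dimensions.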
        if attn_mask is not None:
            attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
            attn_output_weights += attn_mask
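        # Convert the padding mask to additive -inf entries so that padded
        # key positions receive zero attention weight after the softmax.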
        if key_padding_mask is not None:
            attn_output_weights = paddle.reshape(
                attn_output_weights,
                [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
            key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
            key = paddle.cast(key, 'float32')
            y = paddle.full(
                shape=paddle.shape(key), dtype='float32', fill_value='-inf')
            y = paddle.where(key == 0., key, y)
            attn_output_weights += y
        attn_output_weights = F.softmax(
            attn_output_weights.astype('float32'),
            axis=-1,
            dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training)

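        # Attention-weighted sum over the values, then merge the heads back
        # into [tgt_len, batch, embed_dim].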
        attn_output = paddle.matmul(attn_output_weights, v)
        attn_output = paddle.reshape(
            paddle.transpose(attn_output, [2, 0, 1, 3]),
            [q_shape[0], q_shape[1], self.embed_dim])
        attn_output = self.out_proj(attn_output)

        return attn_output

    def _in_proj_q(self, query):
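        # [seq_len, batch, embed_dim] -> [batch, embed_dim, 1, seq_len] so the
        # 1x1 conv acts as a per-position linear projection.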
        query = paddle.transpose(query, [1, 2, 0])
        query = paddle.unsqueeze(query, axis=2)
        res = self.conv1(query)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res

    def _in_proj_k(self, key):
        key = paddle.transpose(key, [1, 2, 0])
        key = paddle.unsqueeze(key, axis=2)
        res = self.conv2(key)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res

    def _in_proj_v(self, value):
        value = paddle.transpose(value, [1, 2, 0])
        value = paddle.unsqueeze(value, axis=2)
        res = self.conv3(value)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res
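

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes example sizes and runs a forward pass with random inputs shaped
# as described in the forward() docstring: [sequence length, batch, embed dim].
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embed_dim, num_heads = 512, 8          # assumed example sizes
    tgt_len, src_len, batch = 25, 100, 2   # assumed sequence lengths / batch

    attn = MultiheadAttention(embed_dim, num_heads, dropout=0.1)
    query = paddle.rand([tgt_len, batch, embed_dim])
    key = paddle.rand([src_len, batch, embed_dim])
    value = paddle.rand([src_len, batch, embed_dim])

    out = attn(query, key, value)
    print(out.shape)  # expected: [25, 2, 512]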