# mha.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func

from bmm1 import *
from bmm2 import *
from padding import *
from softmax import *

class FastUnpadBertSelfAttention(nn.Module):
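    """BERT self-attention over an unpadded (packed) batch of sequences.

    Depending on the constructor flags, the QKV projection, the 1/sqrt(head_size)
    scaling, the attention-mask add, the softmax and the dropout are either fused
    into the custom bmm1/bmm2/softmax kernels or run as separate PyTorch ops.
    """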
    def __init__(self, config, enable_stream=True, enable_sync=True, fuse_mask=True, fuse_scale=True,
                 fuse_qkv=True, fuse_dropout=True, apex_softmax=True, pad=True):
        super(FastUnpadBertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.hidden_size = config.hidden_size

        self.fuse_qkv = fuse_qkv
        self.fuse_scale = fuse_scale
        self.fuse_mask = fuse_mask
        self.fuse_dropout = fuse_dropout
        self.apex_softmax = apex_softmax
        self.pad = pad
        self.enable_stream = enable_stream

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        if self.fuse_qkv:
            self.bmm1 = Bmm1Strided(None, None, self.num_attention_heads, self.attention_head_size,
                                    scale=self.fuse_scale, stream=enable_stream, sync=enable_sync, timer=False)
            self.bmm2 = Bmm2Strided(None, None, self.num_attention_heads, self.attention_head_size,
                                    stream=enable_stream, sync=enable_sync, timer=False)
        else:
            self.bmm1 = Bmm1(None, None, self.num_attention_heads, self.attention_head_size,
                             scale=self.fuse_scale, stream=enable_stream, sync=enable_sync)
            self.bmm2 = Bmm2(None, None, self.num_attention_heads, self.attention_head_size,
                             stream=enable_stream, sync=enable_sync)

        if not self.fuse_dropout:
            self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        if self.fuse_mask and self.fuse_dropout:
            self.softmax = FastMaskSoftmaxDropout(dim=-1, dropout_prob=config.attention_probs_dropout_prob,
                                                  stream=enable_stream, sync=(not self.pad), timer=False)
        elif self.fuse_mask:
            self.softmax = FastMaskSoftmax(dim=-1, stream=enable_stream, sync=enable_sync, timer=False)
        else:
            self.softmax = FastSoftmax(dim=-1, stream=enable_stream, sync=enable_sync, timer=False)

    def transpose_for_scores(self, x):
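        """Split the last dim into (heads, head_size) and return [batch, heads, seq, head_size]."""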
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = torch.reshape(x, new_x_shape)
        return x.permute(0, 2, 1, 3)

    def transpose_key_for_scores(self, x):
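        """Split the last dim into (heads, head_size) and return keys as [batch, heads, head_size, seq]."""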
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = torch.reshape(x, new_x_shape)
        return x.permute(0, 2, 3, 1)

    def pytorch_softmax(self,attention_scores, batch, seqlen, heads):
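        """Reference softmax for the unpadded path: apply F.softmax per sequence on the flat score buffer."""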
        ntokens2 = 0
        for i in range(batch):
            ntokens2 += seqlen[i]*seqlen[i]*self.num_attention_heads
        attention_probs = torch.zeros(ntokens2, device="cuda", dtype=torch.float16)
        ntokens2 = 0
        for i in range(batch):
            tokens2 = seqlen[i]*seqlen[i]*self.num_attention_heads
            attention_probs[ntokens2:ntokens2+tokens2] = F.softmax(
                attention_scores[ntokens2:ntokens2+tokens2].view(1, self.num_attention_heads, seqlen[i], seqlen[i]),
                dim=-1).flatten().contiguous()
            ntokens2 += tokens2
        return attention_probs

    def forward(self, hidden_states, attention_mask, seqlen, batch, is_training=True):
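        """Run self-attention over a packed batch.

        hidden_states is a 2-D [total_tokens, hidden_size] tensor with the
        sequences concatenated (no padding), seqlen holds the per-sequence
        token counts, batch is the number of sequences, and attention_mask is
        consumed by (or added before) the softmax.
        """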

        self.batch = batch

        # QKV
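        # Fused path: pack the per-head Q/K/V weights and biases so a single addmm
        # produces one interleaved QKV buffer for the strided batched GEMMs.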
        if self.fuse_qkv:
            weight = torch.cat(
                [self.query.weight.view(self.num_attention_heads, self.attention_head_size, 1, self.hidden_size),
                 self.key.weight.view(self.num_attention_heads, self.attention_head_size, 1, self.hidden_size),
                 self.value.weight.view(self.num_attention_heads, self.attention_head_size, 1, self.hidden_size)],
                dim=1).reshape(self.all_head_size * 3, self.hidden_size).contiguous()
            bias = torch.cat(
                [self.query.bias.view(self.num_attention_heads, 1, self.attention_head_size),
                 self.key.bias.view(self.num_attention_heads, 1, self.attention_head_size),
                 self.value.bias.view(self.num_attention_heads, 1, self.attention_head_size)],
                dim=1).reshape(3 * self.hidden_size).contiguous()
            mixed_x_layer = torch.addmm(bias, hidden_states, weight.t())
        else:
            query_layer = self.query(hidden_states)
            key_layer = self.key(hidden_states)
            value_layer = self.value(hidden_states)

        # BMM1.
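        # Q @ K^T for every sequence in the packed batch; with fuse_scale the
        # 1/sqrt(head_size) scaling is applied inside the kernel, and the strided
        # (fuse_qkv) variant also returns the interleaved qkv_layer that BMM2 uses for V.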
        if self.enable_stream: torch.cuda.synchronize()
        if self.fuse_qkv:
            attention_scores, qkv_layer = self.bmm1(mixed_x_layer, self.batch, seqlen)
        else:
            attention_scores = self.bmm1(query_layer, key_layer, self.batch, seqlen)

        if self.enable_stream: torch.cuda.synchronize()
        if not self.fuse_scale:
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Softmax.
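        # Fused kernels apply mask (and optionally dropout) together with the softmax;
        # otherwise the mask is added explicitly and the apex or plain PyTorch softmax runs.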
        if self.enable_stream: torch.cuda.synchronize()
        if self.fuse_mask and self.fuse_dropout:
            attention_probs = self.softmax(attention_scores, attention_mask, self.batch, seqlen, self.num_attention_heads, is_training)
        elif self.fuse_mask:
            attention_probs = self.softmax(attention_scores, attention_mask, self.batch, seqlen, self.num_attention_heads)
        else:
            attention_scores = attention_scores + attention_mask.view(-1)
            if self.apex_softmax:
                attention_probs = self.softmax(attention_scores, self.batch, seqlen, self.num_attention_heads)
            else:
                if self.pad:
                    attention_probs = F.softmax(attention_scores.view(batch, self.num_attention_heads, seqlen[0], seqlen[0]), dim=-1).flatten().contiguous()
                else:
                    attention_probs = self.pytorch_softmax(attention_scores, self.batch, seqlen, self.num_attention_heads)

        # Dropout.
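        # Separate dropout is only needed when it was not fused into the softmax.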
        if self.enable_stream: torch.cuda.synchronize()
        if not self.fuse_dropout:
            attention_probs = self.dropout(attention_probs)

        # BMM2.
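        # attention_probs @ V, again per sequence over the packed layout.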
        if self.enable_stream: torch.cuda.synchronize()
        if self.fuse_qkv:
            context_layer = self.bmm2(attention_probs, qkv_layer, self.batch, seqlen)
        else:
            context_layer = self.bmm2(attention_probs, value_layer, self.batch, seqlen)

        if self.enable_stream: torch.cuda.synchronize()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = torch.reshape(context_layer, new_context_layer_shape)
        return context_layer
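

# ---------------------------------------------------------------------------
# Minimal construction sketch (illustration only, not part of the module).
# Assumes the custom bmm1/bmm2/padding/softmax extensions imported above are
# built and that a CUDA device is available; the config fields below are the
# ones this class actually reads, with example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=1024,
        num_attention_heads=16,
        attention_probs_dropout_prob=0.1,
    )
    self_attn = FastUnpadBertSelfAttention(config).cuda().half()
    # forward() then expects a packed fp16 [total_tokens, hidden_size] tensor,
    # the per-sequence lengths in seqlen, the batch size, and an attention mask
    # in the layout required by the chosen softmax kernel.
    print(self_attn)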